def branch_attention(cost_volume_3d, cost_volume_h, cost_volume_v,
                     cost_volume_45, cost_volume_135):
    """Gate the four directional cost volumes with a learned four-channel map.

    ``cost_volume_3d`` goes through two ``convbn`` stages (relu, then sigmoid)
    to produce a four-channel map. Channel ``i`` of that map gates branch ``i``
    in the order h, v, 45, 135: it gets a new axis 1 repeated 9 times, its
    axis 4 is repeated ``4 * 9`` times, and the result is multiplied with the
    matching cost volume.

    Returns:
        tuple: ``(fused, gates)`` where ``fused`` is the four gated volumes
        concatenated along axis 4 and ``gates`` is the four-channel sigmoid map.
    """
    channels = 4 * 9

    # Arguments 6 and 4 are presumably the output channel counts of convbn
    # (helper defined elsewhere) -- TODO confirm.
    gates = convbn(cost_volume_3d, 6, 3, 1, 1)
    gates = Activation('relu')(gates)
    gates = convbn(gates, 4, 3, 1, 1)
    gates = Activation('sigmoid')(gates)

    def _broadcast_gate(index):
        # Keep one gate channel, add axis 1 and repeat it 9 times ...
        picked = Lambda(lambda y: K.repeat_elements(
            K.expand_dims(y[:, :, :, index:index + 1], 1), 9, 1))(gates)
        # ... then repeat the single trailing channel to match the volume.
        return Lambda(lambda y: K.repeat_elements(y, channels, 4))(picked)

    gate_h = _broadcast_gate(0)
    gate_v = _broadcast_gate(1)
    gate_45 = _broadcast_gate(2)
    gate_135 = _broadcast_gate(3)

    gated_volumes = [
        multiply([gate_h, cost_volume_h]),
        multiply([gate_v, cost_volume_v]),
        multiply([gate_45, cost_volume_45]),
        multiply([gate_135, cost_volume_135]),
    ]
    return concatenate(gated_volumes, axis=4), gates
def to_3d_135(cost_volume_135):
    """Apply intra-branch attention to the 135-degree cost volume.

    Two attention passes are applied:

    1. A globally pooled pass: three sigmoid weights are produced from the
       pooled volume, expanded to nine slots (weight 0 four times, weight 1
       once, weight 2 four times), each slot repeated 4 times along the last
       axis, and multiplied with ``cost_volume_135``.
    2. A pass on the re-weighted volume itself: the same 3 -> 9 -> x4
       expansion is applied without pooling and the result multiplies the
       original ``cost_volume_135``.

    The attended volume is then reduced by ``convbn_3d`` stages to a single
    channel, which is squeezed and permuted with ``(0, 2, 3, 1)``.

    Args:
        cost_volume_135: 5-D tensor; presumably channels-last with
            ``4 * 9`` channels so the multiplications line up -- TODO confirm.

    Returns:
        tuple: ``(cost3, cv_135_multi)`` -- the permuted single-channel cost
        map and the attended cost volume.
    """
    feature = 4 * 9
    # Integer division: ``feature / 2`` is a float under Python 3, and layer
    # filter counts must be integers.
    half_feature = feature // 2
    quarter_feature = feature // 4

    # ---- pass 1: weights computed from the globally pooled volume ----
    channel_135 = GlobalAveragePooling3D(
        data_format='channels_last')(cost_volume_135)
    # Restore three singleton axes so the 1x1x1 Conv3D layers can be applied.
    channel_135 = Lambda(lambda y: K.expand_dims(
        K.expand_dims(K.expand_dims(y, 1), 1), 1))(channel_135)
    channel_135 = Conv3D(half_feature, 1, 1, 'same',
                         data_format='channels_last')(channel_135)
    channel_135 = Activation('relu')(channel_135)
    channel_135 = Conv3D(3, 1, 1, 'same',
                         data_format='channels_last')(channel_135)
    channel_135 = Activation('sigmoid')(channel_135)
    # Expand 3 weights to 9 slots: 0,0,0,0,1,2,2,2,2.
    channel_135 = Lambda(lambda y: K.concatenate([
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 1:2],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
    ], axis=-1))(channel_135)
    channel_135 = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 9)))(channel_135)
    # Each of the 9 slots is repeated 4 times along the last axis.
    channel_135 = Lambda(lambda y: K.repeat_elements(y, 4, -1))(channel_135)
    cv_135_tmp = multiply([channel_135, cost_volume_135])

    # ---- pass 2: weights computed from the re-weighted volume ----
    cv_135_tmp = Conv3D(half_feature, 1, 1, 'same',
                        data_format='channels_last')(cv_135_tmp)
    cv_135_tmp = Activation('relu')(cv_135_tmp)
    cv_135_tmp = Conv3D(3, 1, 1, 'same',
                        data_format='channels_last')(cv_135_tmp)
    cv_135_tmp = Activation('sigmoid')(cv_135_tmp)
    attention_135 = Lambda(lambda y: K.concatenate([
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 0:1],
        y[:, :, :, :, 1:2],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
        y[:, :, :, :, 2:3],
    ], axis=-1))(cv_135_tmp)
    attention_135 = Lambda(
        lambda y: K.repeat_elements(y, 4, -1))(attention_135)
    cv_135_multi = multiply([attention_135, cost_volume_135])

    # ---- reduction to a single-channel cost map ----
    # NOTE(review): each of the next four stages reads ``cv_135_multi`` rather
    # than the previous ``dres3``, so the outputs of the first three are
    # overwritten and only the ``quarter_feature`` stage feeds the result.
    # This looks unintended, but chaining them would change the architecture
    # (and any weights trained against it), so the graph is kept as written.
    # Confirm against the sibling to_3d_* builders before changing.
    dres3 = convbn_3d(cv_135_multi, feature, 3, 1)
    dres3 = Activation('relu')(dres3)
    dres3 = convbn_3d(cv_135_multi, half_feature, 3, 1)
    dres3 = Activation('relu')(dres3)
    dres3 = convbn_3d(cv_135_multi, half_feature, 3, 1)
    dres3 = Activation('relu')(dres3)
    dres3 = convbn_3d(cv_135_multi, quarter_feature, 3, 1)
    dres3 = Activation('relu')(dres3)
    dres3 = convbn_3d(dres3, 1, 3, 1)
    cost3 = Activation('relu')(dres3)
    # Drop the singleton last axis and move axis 1 to the end.
    cost3 = Lambda(lambda x: K.permute_dimensions(
        K.squeeze(x, -1), (0, 2, 3, 1)))(cost3)
    return cost3, cv_135_multi
def spatial_attention(cost_volume):
    """Re-weight a cost volume with a single-channel spatial attention map.

    The volume is reduced to one channel by two ``convbn_3d`` stages, squeezed
    and permuted with ``(0, 2, 3, 1)``. Two parallel ``convbn`` paths then
    process it, one using ``(1, k)`` followed by ``(k, 1)`` and the other the
    reverse order (presumably separable kernel sizes -- TODO confirm against
    ``convbn``). Their sum passes through a sigmoid, gets a new axis 1
    repeated 9 times and its axis 4 repeated ``4 * 9`` times, and multiplies
    the input volume.

    Args:
        cost_volume: 5-D tensor; presumably channels-last with ``4 * 9``
            channels and size 9 on axis 1 -- TODO confirm.

    Returns:
        The element-wise product of the attention map and ``cost_volume``.
    """
    feature = 4 * 9
    k = 9
    label = 9

    # Integer division, matching the ``label // 2`` used below; ``feature / 2``
    # is a float under Python 3 and is not a valid channel count.
    dres0 = convbn_3d(cost_volume, feature // 2, 3, 1)
    dres0 = Activation('relu')(dres0)
    dres0 = convbn_3d(dres0, 1, 3, 1)
    cost0 = Activation('relu')(dres0)
    # Drop the singleton last axis and move axis 1 to the end.
    cost0 = Lambda(lambda x: K.permute_dimensions(
        K.squeeze(x, -1), (0, 2, 3, 1)))(cost0)

    # Path 1: (1, k) then (k, 1).
    cost1 = convbn(cost0, label // 2, (1, k), 1, 1)
    cost1 = Activation('relu')(cost1)
    cost1 = convbn(cost1, 1, (k, 1), 1, 1)
    cost1 = Activation('relu')(cost1)

    # Path 2: (k, 1) then (1, k).
    cost2 = convbn(cost0, label // 2, (k, 1), 1, 1)
    cost2 = Activation('relu')(cost2)
    cost2 = convbn(cost2, 1, (1, k), 1, 1)
    cost2 = Activation('relu')(cost2)

    cost = add([cost1, cost2])
    cost = Activation('sigmoid')(cost)

    # Broadcast the map back to the rank and channel count of the input.
    cost = Lambda(
        lambda y: K.repeat_elements(K.expand_dims(y, 1), 9, 1))(cost)
    cost = Lambda(lambda y: K.repeat_elements(y, feature, 4))(cost)
    return multiply([cost, cost_volume])
def disparityregression(input):
    """Return the weighted sum of nine disparity candidates along the last axis.

    The candidates are nine evenly spaced values from -4 to 4. They are tiled
    over the first three dimensions of ``input``, multiplied element-wise with
    it, and summed over the last axis.

    Args:
        input: 4-D tensor whose last axis has 9 entries (the weights).

    Returns:
        Tensor with the last axis reduced away.
    """
    dims = K.shape(input)

    candidates = K.constant(np.linspace(-4, 4, 9), shape=[9])
    # Prepend three singleton axes: (9,) -> (1, 1, 1, 9).
    for _ in range(3):
        candidates = K.expand_dims(candidates, 0)
    # Tile to the full shape of `input` before the element-wise product.
    candidates = tf.tile(candidates, [dims[0], dims[1], dims[2], 1])

    weighted = multiply([input, candidates])
    return K.sum(weighted, -1)
def channel_attention_free(cost_volume):
    """Scale a cost volume with 81 independently learned channel weights.

    The volume is globally average-pooled, passed through two 1x1x1 ``Conv3D``
    layers (170 filters with relu, then 81 filters with sigmoid), and each of
    the 81 resulting weights is repeated 4 times along the last axis before
    multiplying the input.

    Returns:
        tuple: ``(weighted_volume, attention)`` where ``attention`` is the
        sigmoid output reshaped to ``(batch, 1, 1, 1, 81)``.
    """
    pooled = GlobalAveragePooling3D()(cost_volume)
    # Re-insert three singleton axes so Conv3D can consume the pooled vector.
    pooled = Lambda(lambda y: K.expand_dims(
        K.expand_dims(K.expand_dims(y, 1), 1), 1))(pooled)

    hidden = Conv3D(170, 1, 1, 'same')(pooled)
    hidden = Activation('relu')(hidden)
    weights = Conv3D(81, 1, 1, 'same')(hidden)
    weights = Activation('sigmoid')(weights)

    attention = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(weights)
    # Each weight covers 4 consecutive channels of the volume.
    scale = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([scale, cost_volume]), attention
def define_AttMLFNet(sz_input, sz_input2, view_n, learning_rate):
    """Build and compile the four-branch AttMLFNet model.

    The model takes ``len(view_n) * 4`` single-channel inputs, ordered as four
    consecutive groups of ``len(view_n)`` views: horizontal, vertical,
    45-degree and 135-degree. A shared feature extractor is applied to every
    input, a cost volume is built per branch, each volume goes through its
    intra-branch attention, the four are fused by ``branch_attention``, and
    the fused volume is regressed to a disparity map.

    Args:
        sz_input: First spatial dimension of each input.
        sz_input2: Second spatial dimension of each input.
        view_n: Sequence whose length is the number of views per branch. The
            attention helpers in this file hard-code 9, so other lengths are
            presumably unsupported downstream -- TODO confirm.
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        The compiled Keras ``Model`` (loss ``'mae'``). The model summary is
        printed as a side effect.
    """
    views_per_branch = len(view_n)

    # 4 branches inputs
    input_list = [
        Input(shape=(sz_input, sz_input2, 1))
        for _ in range(views_per_branch * 4)
    ]

    # 4 branches features (one shared extractor applied to every view)
    feature_extraction_layer = feature_extraction(sz_input, sz_input2)
    feature_list = [feature_extraction_layer(view) for view in input_list]

    # Split into the four branches by the number of views per branch rather
    # than hard-coded 9/18/27 offsets, so the split always matches the number
    # of inputs created above (identical result for 9 views).
    feature_h_list = feature_list[:views_per_branch]
    feature_v_list = feature_list[views_per_branch:2 * views_per_branch]
    feature_45_list = feature_list[2 * views_per_branch:3 * views_per_branch]
    feature_135_list = feature_list[3 * views_per_branch:]

    # cost volume
    cv_h = Lambda(_get_h_CostVolume_)(feature_h_list)
    cv_v = Lambda(_get_v_CostVolume_)(feature_v_list)
    cv_45 = Lambda(_get_45_CostVolume_)(feature_45_list)
    cv_135 = Lambda(_get_135_CostVolume_)(feature_135_list)

    # intra branch
    cv_h_3d, cv_h_ca = to_3d_h(cv_h)
    cv_v_3d, cv_v_ca = to_3d_v(cv_v)
    cv_45_3d, cv_45_ca = to_3d_45(cv_45)
    cv_135_3d, cv_135_ca = to_3d_135(cv_135)

    # inter branch; the second return value (the branch gates) is not used
    # further in this function.
    cv, attention_4 = branch_attention(
        multiply([cv_h_3d, cv_v_3d, cv_45_3d, cv_135_3d]),
        cv_h_ca, cv_v_ca, cv_45_ca, cv_135_ca)

    # cost volume regression
    cost = basic(cv)
    # Drop the singleton last axis and move axis 1 to the end so the softmax
    # and the regression operate over it.
    cost = Lambda(lambda x: K.permute_dimensions(
        K.squeeze(x, -1), (0, 2, 3, 1)))(cost)
    pred = Activation('softmax')(cost)
    pred = Lambda(disparityregression)(pred)

    model = Model(inputs=input_list, outputs=[pred])
    model.summary()

    opt = Adam(lr=learning_rate)
    model.compile(optimizer=opt, loss='mae')
    return model
def channel_attention_mirror(cost_volume):
    """Scale a cost volume with 81 channel weights mirrored from a 5x5 grid.

    The volume is globally average-pooled and passed through two 1x1x1
    ``Conv3D`` layers (170 filters with relu, then 25 filters with sigmoid).
    The 25 weights are reshaped to a 5x5 grid and reflect-padded by 4 at the
    end of each grid axis to 9x9, giving 81 weights; each is repeated 4 times
    along the last axis before multiplying the input.

    Returns:
        tuple: ``(weighted_volume, attention)`` where ``attention`` has shape
        ``(batch, 1, 1, 1, 81)``.
    """
    pooled = GlobalAveragePooling3D()(cost_volume)
    # Re-insert three singleton axes so Conv3D can consume the pooled vector.
    pooled = Lambda(lambda y: K.expand_dims(
        K.expand_dims(K.expand_dims(y, 1), 1), 1))(pooled)

    hidden = Conv3D(170, 1, 1, 'same')(pooled)
    hidden = Activation('relu')(hidden)
    weights = Conv3D(25, 1, 1, 'same')(hidden)
    weights = Activation('sigmoid')(weights)

    # 25 weights -> 5x5 grid -> mirrored to 9x9.
    grid = Lambda(lambda y: K.reshape(y, (K.shape(y)[0], 5, 5)))(weights)
    grid = Lambda(
        lambda y: tf.pad(y, [[0, 0], [0, 4], [0, 4]], 'REFLECT'))(grid)

    attention = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(grid)
    # Each weight covers 4 consecutive channels of the volume.
    scale = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([scale, cost_volume]), attention
def channel_attention(cost_volume):
    """Scale a cost volume with 81 channel weights built from 15 learned values.

    The volume is globally average-pooled and passed through two 1x1x1
    ``Conv3D`` layers (170 filters with relu, then 15 filters with sigmoid).
    The 15 weights fill the upper triangle of a symmetric 5x5 grid, which is
    reflect-padded by 4 at the end of each grid axis to 9x9, giving 81
    weights; each is repeated 4 times along the last axis before multiplying
    the input.

    Returns:
        tuple: ``(weighted_volume, attention)`` where ``attention`` has shape
        ``(batch, 1, 1, 1, 81)``.
    """
    pooled = GlobalAveragePooling3D()(cost_volume)
    # Re-insert three singleton axes so Conv3D can consume the pooled vector.
    pooled = Lambda(lambda y: K.expand_dims(
        K.expand_dims(K.expand_dims(y, 1), 1), 1))(pooled)

    hidden = Conv3D(170, 1, 1, 'same')(pooled)
    hidden = Activation('relu')(hidden)
    weights = Conv3D(15, 1, 1, 'same')(hidden)  # [B, 1, 1, 1, 15]
    weights = Activation('sigmoid')(weights)

    # 15 -> 25: the 15 values are the upper triangle of a 5x5 matrix,
    #
    #     0  1  2  3  4
    #        5  6  7  8
    #           9 10 11
    #             12 13
    #                14
    #
    # and are laid out row by row as the full symmetric matrix:
    #
    #     0  1  2  3  4
    #     1  5  6  7  8
    #     2  6  9 10 11
    #     3  7 10 12 13
    #     4  8 11 13 14
    #
    # Each (start, stop) pair below is one channel slice, in output order.
    spans = (
        (0, 5),                                # row 0
        (1, 2), (5, 9),                        # row 1
        (2, 3), (6, 7), (9, 12),               # row 2
        (3, 4), (7, 8), (10, 11), (12, 14),    # row 3
        (4, 5), (8, 9), (11, 12), (13, 15),    # row 4
    )
    symmetric = Lambda(lambda y: K.concatenate(
        [y[:, :, :, :, start:stop] for start, stop in spans],
        axis=-1))(weights)

    # 25 weights -> 5x5 grid -> mirrored to 9x9.
    grid = Lambda(lambda y: K.reshape(y, (K.shape(y)[0], 5, 5)))(symmetric)
    grid = Lambda(
        lambda y: tf.pad(y, [[0, 0], [0, 4], [0, 4]], 'REFLECT'))(grid)

    attention = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(grid)
    # Each weight covers 4 consecutive channels of the volume.
    scale = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([scale, cost_volume]), attention