def _rename_model(model, new_name):
    """Rename a Keras model in a way that works across Keras versions.

    Plain assignment to ``model.name`` works where ``name`` is an ordinary
    attribute. Some Keras versions (e.g. tf.keras 2.x) expose ``name`` as a
    read-only property, so that assignment raises AttributeError; in that
    case the backing ``_name`` attribute is set instead.
    """
    try:
        model.name = new_name
    except AttributeError:
        model._name = new_name


def attention_model(classes, backbone, shape):
    """Build a two-input classifier that fuses two backbone streams with Attention.

    Each of the two inputs is fed through its own separately constructed
    backbone instance (two distinct objects, so no weights are shared). The
    two stream outputs are combined by the ``Attention`` layer, and a final
    ``Dense`` layer named ``'predictions'`` produces the output.

    Args:
        classes: Number of output units of the final Dense layer. When it is
            1, the layer uses a sigmoid activation; otherwise it uses softmax.
            The value is also passed to the MobileNetV3 builders.
        backbone: ``'Xception'`` or ``'MobileNetV3_Large'``. Any other value
            falls back to ``MobileNetV3_Small``.
        shape: Input shape used for both ``Input`` tensors and for the
            backbones.

    Returns:
        A ``Model`` with inputs ``[input1, input2]`` and the Dense
        ``'predictions'`` output.
    """
    if backbone == 'Xception':
        # Headless Xception backbones initialised with ImageNet weights.
        stream1 = Xception(include_top=False, weights='imagenet', input_shape=shape)
        stream2 = Xception(include_top=False, weights='imagenet', input_shape=shape)
    elif backbone == 'MobileNetV3_Large':
        stream1 = MobileNetV3_Large(shape, classes).build()
        stream2 = MobileNetV3_Large(shape, classes).build()
    else:  # any other value falls back to MobileNetV3_Small
        stream1 = MobileNetV3_Small(shape, classes).build()
        stream2 = MobileNetV3_Small(shape, classes).build()

    # Two models from the same constructor would likely carry the same default
    # name; give them distinct names so the combined Model has no clash.
    _rename_model(stream1, 'stream1')
    _rename_model(stream2, 'stream2')

    input1 = Input(shape)
    input2 = Input(shape)
    output1 = stream1(input1)
    output2 = stream2(input2)

    if backbone == 'Xception':
        # include_top=False leaves a spatial feature map; pool it to a vector.
        output1 = GlobalAveragePooling2D(name='avg_pool_1')(output1)
        output2 = GlobalAveragePooling2D(name='avg_pool_2')(output2)
    # NOTE(review): for the MobileNetV3 backbones this assumes build() already
    # returns a 2-D (batch, features) tensor — TODO confirm.

    # Attention is a project layer; it is sized by the feature width of stream 1.
    output = Attention(size=output1.shape[1])([output1, output2])

    activation = 'sigmoid' if classes == 1 else 'softmax'
    output = Dense(classes, activation=activation, name='predictions')(output)

    return Model(inputs=[input1, input2], outputs=output)