def attention_model(classes, backbone, shape):
    """Build a two-input, two-stream classifier fused by an ``Attention`` layer.

    Two separately constructed backbone instances (``stream1`` and
    ``stream2``) each process their own input tensor. Their outputs are
    combined by ``Attention`` and passed to a final ``Dense`` layer named
    ``'predictions'``.

    Args:
        classes: Number of output units of the final ``Dense`` layer.
            ``1`` selects a sigmoid activation; any other value selects
            softmax. Also forwarded to the MobileNetV3 builders.
        backbone: ``'Xception'`` or ``'MobileNetV3_Large'``; any other
            value falls back to ``MobileNetV3_Small``.
        shape: Input shape given to both backbones and both ``Input`` layers.

    Returns:
        A ``Model`` whose inputs are the two ``Input`` tensors (in order)
        and whose output is the prediction tensor.
    """
    def make_stream():
        # Each call constructs a separate backbone instance.
        if backbone == 'Xception':
            return Xception(include_top=False, weights='imagenet', input_shape=shape)
        if backbone == 'MobileNetV3_Large':
            # NOTE(review): the MobileNetV3 builders receive `classes`, so their
            # output may already be a classification head rather than raw
            # features — confirm this is intended.
            return MobileNetV3_Large(shape, classes).build()
        # Fallback for any other backbone value.
        return MobileNetV3_Small(shape, classes).build()

    first_stream = make_stream()
    second_stream = make_stream()
    # Presumably renamed so the two nested backbones have unique names
    # inside the combined Model — confirm against the Keras version in use.
    first_stream.name = 'stream1'
    second_stream.name = 'stream2'

    left_input = Input(shape)
    right_input = Input(shape)
    left_features = first_stream(left_input)
    right_features = second_stream(right_input)

    # Only the Xception backbone (built with include_top=False) is pooled
    # here; the MobileNetV3 builds are used as returned.
    if backbone == 'Xception':
        left_features = GlobalAveragePooling2D(name='avg_pool_1')(left_features)
        right_features = GlobalAveragePooling2D(name='avg_pool_2')(right_features)

    fused = Attention(size=left_features.shape[1])([left_features, right_features])

    # Single-unit output -> sigmoid; otherwise softmax over `classes` units.
    activation = 'sigmoid' if classes == 1 else 'softmax'
    predictions = Dense(classes, activation=activation, name='predictions')(fused)

    return Model(inputs=[left_input, right_input], outputs=predictions)