def tiny_yolo3_peleenet_body(inputs, num_anchors, num_classes):
    '''Create Tiny YOLO_v3 PeleeNet model CNN body in keras.

    Builds a PeleeNet backbone (ImageNet weights, no classification top)
    on `inputs`, taps two of its intermediate feature maps and feeds them
    to `tiny_yolo3_predictions` to obtain the two output branches.

    # Arguments
        inputs: Keras input tensor for the image.
        num_anchors: anchor count, forwarded unchanged to
            `tiny_yolo3_predictions`.
        num_classes: number of object classes, forwarded unchanged to
            `tiny_yolo3_predictions`.

    # Returns
        A keras `Model` mapping `inputs` to `[y1, y2]`.
    '''
    peleenet = PeleeNet(input_tensor=inputs,
                        weights='imagenet',
                        include_top=False)
    print('backbone layers number: {}'.format(len(peleenet.layers)))

    # input: 416 x 416 x 3
    # re_lu_338(layer 365, final feature map): 13 x 13 x 704
    # re_lu_307(layer 265, end of stride 16) : 26 x 26 x 512
    # re_lu_266(layer 133, end of stride 8)  : 52 x 52 x 256

    # NOTE: activation layer name may differ between TF1.x/2.x, so we
    # use the layer index to fetch feature maps.
    # f1: 13 x 13 x 704
    f1 = peleenet.layers[365].output
    # f2: 26 x 26 x 512
    f2 = peleenet.layers[265].output
    # The stride-8 map (layer 133, 52 x 52 x 256) is only needed by the
    # full 3-scale YOLOv3 head; the tiny variant uses just f1 and f2.

    # Channel counts of f1/f2, matching the shapes listed above.
    f1_channel_num = 704
    f2_channel_num = 512

    y1, y2 = tiny_yolo3_predictions((f1, f2), (f1_channel_num, f2_channel_num),
                                    num_anchors, num_classes)

    return Model(inputs, [y1, y2])
def get_base_model(model_type, model_input_shape, weights='imagenet'):
    '''Build a headless backbone network on a fresh 3-channel image input.

    # Arguments
        model_type: backbone name, one of 'mobilenet', 'mobilenetv2',
            'mobilenetv3large', 'mobilenetv3small', 'peleenet', 'ghostnet',
            'squeezenet', 'mobilevit_s', 'mobilevit_xs', 'mobilevit_xxs',
            'resnet50', 'simple_cnn', 'simple_cnn_lite'.
        model_input_shape: spatial input shape (presumably (height, width));
            any sequence (tuple or list) is accepted and a channel
            dimension of 3 is appended to it.
        weights: weights passed to the backbone constructor (default
            'imagenet'). Ignored for 'simple_cnn' and 'simple_cnn_lite',
            which are always built with weights=None.

    # Returns
        The backbone model, built with include_top=False and pooling=None
        on top of an Input layer named 'image_input'.

    # Raises
        ValueError: if `model_type` is not one of the supported names.
    '''
    # tuple() so callers may pass a list; `list + tuple` would raise TypeError.
    input_shape = tuple(model_input_shape) + (3, )
    input_tensor = Input(shape=input_shape, name='image_input')

    # Arguments shared by every backbone constructor below.
    common_kwargs = dict(input_tensor=input_tensor,
                         input_shape=input_shape,
                         pooling=None,
                         include_top=False)

    # An if/elif chain (rather than a dict of constructors) keeps the
    # constructor-name lookup lazy: only the selected backbone's name
    # needs to be resolvable at call time.
    if model_type == 'mobilenet':
        model = MobileNet(weights=weights, alpha=0.5, **common_kwargs)
    elif model_type == 'mobilenetv2':
        model = MobileNetV2(weights=weights, alpha=0.5, **common_kwargs)
    elif model_type == 'mobilenetv3large':
        model = MobileNetV3Large(weights=weights, alpha=0.75, **common_kwargs)
    elif model_type == 'mobilenetv3small':
        model = MobileNetV3Small(weights=weights, alpha=0.75, **common_kwargs)
    elif model_type == 'peleenet':
        model = PeleeNet(weights=weights, **common_kwargs)
    elif model_type == 'ghostnet':
        model = GhostNet(weights=weights, **common_kwargs)
    elif model_type == 'squeezenet':
        model = SqueezeNet(weights=weights, **common_kwargs)
    elif model_type == 'mobilevit_s':
        model = MobileViT_S(weights=weights, **common_kwargs)
    elif model_type == 'mobilevit_xs':
        model = MobileViT_XS(weights=weights, **common_kwargs)
    elif model_type == 'mobilevit_xxs':
        model = MobileViT_XXS(weights=weights, **common_kwargs)
    elif model_type == 'resnet50':
        model = ResNet50(weights=weights, **common_kwargs)
    elif model_type == 'simple_cnn':
        # No pretrained weights are used for this backbone.
        model = SimpleCNN(weights=None, **common_kwargs)
    elif model_type == 'simple_cnn_lite':
        # No pretrained weights are used for this backbone.
        model = SimpleCNNLite(weights=None, **common_kwargs)
    else:
        raise ValueError('Unsupported model type: {}'.format(model_type))
    return model