def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=2,
            weights_path='model_data/yolo_weights.h5', ignore_thresh=0.5):
    '''Create the distillation training model.

    Builds a MobileNetV2-YOLO body with ``mobilenetv2_yolo_body`` and wraps it
    in a ``Lambda`` layer that computes ``yolo_distill_loss``. The returned
    model therefore outputs the loss directly.

    Args:
        input_shape: ``(h, w)`` of the training images. Used to size the
            per-scale target inputs (``h // stride``, ``w // stride``) and to
            name the output layers. Presumably multiples of 32 — TODO confirm.
        anchors: anchor boxes. ``len(anchors)`` is split evenly (``// 3``)
            across the three output scales.
        num_classes: number of object classes.
        load_pretrained: if True, load ``weights_path`` by layer name and
            skip layers whose shapes do not match.
        freeze_body: only used when ``load_pretrained`` is True.
            1 freezes the first 185 layers (clamped to the layer count).
            2 freezes all but the last 3 layers.
            Any other value freezes nothing.
        weights_path: path of the weights file to load.
        ignore_thresh: value forwarded to ``yolo_distill_loss`` as
            ``ignore_thresh`` (default 0.5, as before).

    Returns:
        A Keras ``Model`` with inputs ``[image, *y_true, *l_true]`` whose
        output is the ``yolo_distill_loss`` value.
    '''
    K.clear_session()  # get a new session
    image_input = Input(shape=(None, None, 3))
    h, w = input_shape
    num_anchors = len(anchors)
    # Downsampling factor of each output scale, ordered coarsest to finest.
    strides = (32, 16, 8)

    def _target_inputs():
        # One placeholder per scale: (grid_h, grid_w, anchors_per_scale, 5 + num_classes).
        return [Input(shape=(h // s, w // s, num_anchors // 3, num_classes + 5))
                for s in strides]

    # Ground-truth targets, one per output scale.
    y_true = _target_inputs()
    # A second set of targets with identical shapes, also fed to yolo_distill_loss
    # (presumably teacher-model predictions for distillation — TODO confirm).
    l_true = _target_inputs()

    model_body = mobilenetv2_yolo_body(image_input, num_anchors // 3, num_classes)
    print('Create YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))

    if load_pretrained:
        model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
        print('Load weights {}.'.format(weights_path))
        if freeze_body in [1, 2]:
            # Freeze the first 185 layers of the MobileNetV2-YOLO body,
            # or freeze everything except the 3 output layers.
            num_layers = len(model_body.layers)
            num = (185, num_layers - 3)[freeze_body - 1]
            # Clamp to [0, num_layers] so a smaller body does not raise IndexError.
            num = max(0, min(num, num_layers))
            for layer in model_body.layers[:num]:
                layer.trainable = False
            print('Freeze the first {} layers of total {} layers.'.format(num, num_layers))

    # Rename the last three layers after their grid height, e.g. conv2d_output_13 /
    # _26 / _52 for h=416. This happens after load_weights, so these names are not
    # used for the by-name weight loading above.
    for layer, stride in zip(model_body.layers[-3:], strides):
        layer.name = "conv2d_output_" + str(h // stride)

    model_loss = Lambda(yolo_distill_loss, output_shape=(1,), name='yolo_distill_loss',
        arguments={'anchors': anchors, 'num_classes': num_classes,
                   'ignore_thresh': ignore_thresh})(
        [*model_body.output, *y_true, *l_true])
    model = Model([model_body.input, *y_true, *l_true], model_loss)

    return model
# Module-level script: builds the MobileNetV2-YOLO body once, prints its summary,
# saves its freshly initialised weights and plots the architecture.
# NOTE(review): this runs at import time and uses `anchors`, `image_input`,
# `num_classes` and `plot`, none of which are defined in this section — they
# must already exist at module level when these lines execute; confirm.
num_anchors = len(anchors)

# Alternative backbones, kept for reference:
#mobilenet_model = MobileNet(input_tensor=image_input,weights='imagenet')
#mobilenet_model = mobilenet_yolo_body(image_input, num_anchors//3, num_classes)
#plot(mobilenet_model, to_file='{}.png'.format("mobilenet_yolo"), show_shapes=True)
#mobilenet_model.summary()
#mobilenet_model.save_weights('empty_mobilenet.h5')

#mobilenetv2_model = MobileNetV2(input_tensor=image_input,weights='imagenet')
#mobilenetv2_model = mobilenetv2_yolo_body(image_input, num_anchors//3, num_classes)
#plot(mobilenetv2_model, to_file='{}.png'.format("mobilenetv2_yolo"), show_shapes=True)
#mobilenetv2_model.summary()
#mobilenetv2_model.save_weights('empty_mobilenetv2.h5')

mobilenetv2 = mobilenetv2_yolo_body(image_input, num_anchors // 3, num_classes)
mobilenetv2.summary()
# Use a v2-specific filename so these weights do not overwrite the MobileNet (v1)
# empty-weights file 'empty_mobilenet.h5' (matches the commented v2 variant above).
mobilenetv2.save_weights('empty_mobilenetv2.h5')
plot(mobilenetv2,
     to_file='{}.png'.format("mobilenetv2_yolo"),
     show_shapes=True)

#squeezenet_model = squeezenet_body( input_tensor = image_input )
#squeezenet_model.summary()
#squeezenet_model = squeezenet_yolo_body(image_input, num_anchors//3, num_classes)
#plot(squeezenet_model , to_file='{}.png'.format("squeezenet_yolo"), show_shapes=True)
#squeezenet_model.summary()
#squeezenet_model.save_weights('empty_squeezenet.h5')

#tiny_model = tiny_darknet_yolo_body(image_input, num_anchors//3, num_classes)
#tiny_model.summary()