Ejemplo n.º 1
0
def get_segm_depth_model(input, output_segm, output_depth):
    """Build a two-headed model with segmentation and depth outputs.

    The segmentation head is flattened to ``(output_height * output_width,
    n_classes)`` and passed through a softmax layer named ``"segm_pred"``.
    The depth head is passed through a sigmoid layer named ``"depth_pred"``
    without any reshaping.

    Args:
        input: Input tensor of the network. The name shadows the builtin but
            is kept so keyword callers keep working.
        output_segm: Tensor produced by the segmentation branch. It is laid
            out according to the module-level ``IMAGE_ORDERING``.
        output_depth: Tensor produced by the depth branch.

    Returns:
        A ``Model`` with outputs ``[segm_pred, depth_pred]``. It carries
        these extra attributes:
            - ``n_classes``, ``input_height``, ``input_width``,
              ``output_height`` and ``output_width``;
            - ``train`` and ``predict_segmentation``, bound methods wrapping
              the module-level ``train`` and ``predict``.

    Raises:
        ValueError: If ``IMAGE_ORDERING`` is neither ``'channels_first'`` nor
            ``'channels_last'``.
    """
    img_input = input

    # Build the segmentation-only model once, solely to read its static
    # input/output shapes.
    segm_only = Model(img_input, output_segm)
    o_shape = segm_only.output_shape
    i_shape = segm_only.input_shape

    if IMAGE_ORDERING == 'channels_first':
        # Class/channel axis is index 1; spatial dims follow it.
        n_classes = o_shape[1]
        output_height, output_width = o_shape[2], o_shape[3]
        input_height, input_width = i_shape[2], i_shape[3]
        # Flatten the spatial dims, then swap axes so the class axis is last.
        # That is the axis the softmax below is applied over (Keras default
        # axis=-1).
        output_segm = Reshape((-1, output_height * output_width))(output_segm)
        output_segm = Permute((2, 1))(output_segm)
    elif IMAGE_ORDERING == 'channels_last':
        # Spatial dims at indices 1 and 2; class/channel axis at index 3.
        n_classes = o_shape[3]
        output_height, output_width = o_shape[1], o_shape[2]
        input_height, input_width = i_shape[1], i_shape[2]
        # Class axis is already last; only the spatial dims need flattening.
        output_segm = Reshape((output_height * output_width, -1))(output_segm)
    else:
        # Without this check, an unknown ordering produced a model with all
        # sizes set to 0 and an un-flattened segmentation output.
        raise ValueError(
            "IMAGE_ORDERING must be 'channels_first' or 'channels_last', "
            "got {!r}".format(IMAGE_ORDERING)
        )

    output_segm = Activation('softmax', name="segm_pred")(output_segm)
    output_depth = Activation('sigmoid', name="depth_pred")(output_depth)

    model = Model(inputs=img_input, outputs=[output_segm, output_depth])

    model.n_classes = n_classes
    model.input_height = input_height
    model.input_width = input_width
    model.output_height = output_height
    model.output_width = output_width
    model.train = MethodType(train, model)
    model.predict_segmentation = MethodType(predict, model)
    # NOTE(review): predict_multiple and evaluate_segmentation are deliberately
    # not attached. Presumably they assume a single-output model; confirm they
    # handle the [segm, depth] outputs before binding them here.

    return model