コード例 #1
0
def load_vgg_16(model_dir, sess):
    """Build a TF-Slim VGG-16 feature extractor and restore pretrained weights.

    Downloads and extracts the VGG-16 checkpoint archive under ``model_dir``
    on a best-effort basis. It then builds the network on a new float32 input
    placeholder and restores the checkpoint into ``sess``.

    Args:
        model_dir: Directory under which the checkpoint archive is stored
            and extracted.
        sess: Session into which the checkpoint variables are restored.

    Returns:
        A tuple ``(bottleneck_tensor, resized_input_tensor,
        bottleneck_tensor_size)``:

        - ``bottleneck_tensor``: the globally pooled network output with
          axes 1 and 2 squeezed out.
        - ``resized_input_tensor``: the ``[None, None, None, 3]`` float32
          input placeholder.
        - ``bottleneck_tensor_size``: the feature width, 4096.
    """
    model_url = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"

    filename = model_url.split("/")[-1]
    # This archive ends in ".tar.gz", so splitting on ".tgz" was a no-op.
    # The extraction directory kept the full archive name. Strip the real
    # extension so the directory is "vgg_16_2016_08_28".
    filepath = os.path.join(model_dir, filename.split(".tar.gz")[0])

    try:
        utils.download_pretrained_model_weights(model_url,
                                                filepath,
                                                unzip=True)
    except Exception as err:
        # Best effort: a failed download is reported, not raised (presumably
        # so weights already on disk can still be used). Catching Exception
        # instead of using a bare except lets KeyboardInterrupt/SystemExit
        # propagate.
        print("Pre-training weights download failed!", err)

    model_file_name = "vgg_16.ckpt"
    model_path = os.path.join(filepath, model_file_name)

    resized_input_tensor = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 3])
    with tf.contrib.slim.arg_scope(vgg.vgg_arg_scope()):
        # num_classes=None skips the classification layer.
        # global_pool=True pools the spatial dimensions.
        bottleneck_tensor, _ = vgg.vgg_16(resized_input_tensor,
                                          num_classes=None,
                                          global_pool=True)

    # Restore every trainable variable that exists in the checkpoint.
    # Variables missing from it are skipped rather than raising.
    variable_restore_op = tf.contrib.slim.assign_from_checkpoint_fn(
        model_path,
        tf.contrib.slim.get_trainable_variables(),
        ignore_missing_vars=True)
    variable_restore_op(sess)

    # Drop the singleton spatial axes left by global pooling.
    bottleneck_tensor = tf.squeeze(bottleneck_tensor, axis=[1, 2])
    bottleneck_tensor_size = 4096

    return bottleneck_tensor, resized_input_tensor, bottleneck_tensor_size
コード例 #2
0
def load_densenet_201(model_dir):
    """Build a Keras DenseNet-201 feature extractor with pretrained weights.

    Downloads the ImageNet "notop" weight file into ``<model_dir>/densenet``
    on a best-effort basis. It then builds DenseNet-201 without its
    classification head, using global average pooling, and loads the weights.

    Args:
        model_dir: Directory under which the ``densenet`` weight folder lives.

    Returns:
        A tuple ``(model, bottleneck_tensor, resized_input_tensor,
        bottleneck_tensor_size)``:

        - ``model``: the Keras DenseNet-201 model.
        - ``bottleneck_tensor``: a new float32 placeholder of shape
          ``[None, 1920]``, not the model's output tensor.
        - ``resized_input_tensor``: always ``None``.
        - ``bottleneck_tensor_size``: 1920.

        NOTE(review): callers presumably run ``model`` themselves and feed
        its features into ``bottleneck_tensor`` -- confirm against callers.
    """
    model_url = "https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5"

    filepath = os.path.join(model_dir, "densenet")

    try:
        # The .h5 weight file is not an archive, so no extraction is needed.
        utils.download_pretrained_model_weights(model_url,
                                                filepath,
                                                unzip=False)
    except Exception as err:
        # Best effort: report, don't raise. A bare except would also have
        # swallowed KeyboardInterrupt/SystemExit and hidden the cause.
        print("Pre-training weights download failed!", err)

    model_file_name = model_url.split("/")[-1]
    model_path = os.path.join(filepath, model_file_name)

    with tf.name_scope("DenseNet"):
        # weights=None builds the architecture without Keras' own download.
        # The weights come from the file fetched above.
        model = tf.keras.applications.densenet.DenseNet201(include_top=False,
                                                           weights=None,
                                                           pooling='avg')
        model.load_weights(model_path)

        bottleneck_tensor_size = 1920
        bottleneck_tensor = tf.placeholder(tf.float32,
                                           [None, bottleneck_tensor_size])
        resized_input_tensor = None

    return model, bottleneck_tensor, resized_input_tensor, bottleneck_tensor_size
コード例 #3
0
def load_inception_v3(model_dir):
    """Load the frozen Inception-v3 graph and return its feature tensors.

    Downloads and extracts the 2015-12-05 Inception archive under
    ``model_dir`` on a best-effort basis. It then imports the frozen
    ``GraphDef`` into a new graph and looks up the bottleneck and input
    tensors by name.

    Args:
        model_dir: Directory under which the archive is stored and extracted.

    Returns:
        A tuple ``(graph, bottleneck_tensor, resized_input_tensor,
        bottleneck_tensor_size)``:

        - ``graph``: the new ``tf.Graph``; both tensors belong to it, not to
          the default graph.
        - ``bottleneck_tensor``: the ``pool_3/_reshape:0`` tensor.
        - ``resized_input_tensor``: the ``Mul:0`` tensor.
        - ``bottleneck_tensor_size``: 2048.
    """
    # Previously bound as ``inception_url`` while ``model_url`` was used
    # below, which raised NameError on every call.
    model_url = "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
    bottleneck_tensor_name = "pool_3/_reshape:0"
    resized_input_tensor_name = "Mul:0"
    model_file_name = "classify_image_graph_def.pb"

    filename = model_url.split("/")[-1]
    filepath = os.path.join(model_dir, filename.split(".tgz")[0])

    try:
        utils.download_pretrained_model_weights(model_url,
                                                filepath,
                                                unzip=True)
    except Exception as err:
        # Best effort: report, don't raise. Catching Exception instead of
        # using a bare except lets KeyboardInterrupt/SystemExit propagate.
        print("Pre-training weights download failed!", err)

    with tf.Graph().as_default() as graph:
        model_path = os.path.join(filepath, model_file_name)
        with gfile.FastGFile(model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            # name="" keeps the original node names, so the tensor names
            # above resolve without a prefix.
            bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
                graph_def,
                name="",
                return_elements=[
                    bottleneck_tensor_name,
                    resized_input_tensor_name,
                ]))

    bottleneck_tensor_size = 2048

    return graph, bottleneck_tensor, resized_input_tensor, bottleneck_tensor_size
コード例 #4
0
def load_mobilenet_v2(model_dir, sess):
    """Build a MobileNet-v2 (depth multiplier 1.4) extractor with weights.

    Downloads and extracts the ``mobilenet_v2_1.4_224`` checkpoint under
    ``model_dir`` on a best-effort basis. It then builds the network on a new
    float32 input placeholder and restores the checkpoint into ``sess``.

    Args:
        model_dir: Directory under which the checkpoint archive is stored
            and extracted.
        sess: Session into which the checkpoint variables are restored.

    Returns:
        A tuple ``(bottleneck_tensor, resized_input_tensor,
        bottleneck_tensor_size)``:

        - ``bottleneck_tensor``: the network output as returned by
          ``mobilenet_v2.mobilenet``, with no squeeze applied.
        - ``resized_input_tensor``: the ``[None, None, None, 3]`` float32
          input placeholder.
        - ``bottleneck_tensor_size``: the feature width, 1792.
    """
    model_url = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz"

    filename = model_url.split("/")[-1]
    filepath = os.path.join(model_dir, filename.split(".tgz")[0])

    try:
        utils.download_pretrained_model_weights(model_url,
                                                filepath,
                                                unzip=True)
    except Exception as err:
        # Best effort: report, don't raise. Catching Exception instead of
        # using a bare except lets KeyboardInterrupt/SystemExit propagate.
        print("Pre-training weights download failed!", err)

    model_file_name = "mobilenet_v2_1.4_224.ckpt"
    model_path = os.path.join(filepath, model_file_name)

    resized_input_tensor = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 3])
    with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope()):
        # num_classes=None skips the classification layer.
        bottleneck_tensor, _ = mobilenet_v2.mobilenet(resized_input_tensor,
                                                      num_classes=None,
                                                      depth_multiplier=1.4)

    # Restore every trainable variable that exists in the checkpoint.
    # Variables missing from it are skipped rather than raising.
    variable_restore_op = tf.contrib.slim.assign_from_checkpoint_fn(
        model_path,
        tf.contrib.slim.get_trainable_variables(),
        ignore_missing_vars=True)
    variable_restore_op(sess)

    # NOTE(review): unlike load_vgg_16, no tf.squeeze(..., axis=[1, 2]) is
    # applied here (a squeeze was previously commented out). If the slim
    # output keeps singleton spatial axes, callers must handle them --
    # confirm before relying on a [None, 1792] shape.
    bottleneck_tensor_size = 1792

    return bottleneck_tensor, resized_input_tensor, bottleneck_tensor_size