示例#1
0
def tfslim_vgg16(batch_size=64):
    """Build a TF-slim VGG-16 graph and wrap it in a ``TensorflowSlimWrapper``.

    The default graph is reset. A string placeholder with ``batch_size``
    elements is then mapped element-wise through ``load_resize_image`` and
    ``vgg_preprocessing.preprocess_image`` (224x224). The result is fed into
    the ``'vgg_16'`` network from ``nets_factory`` in inference mode
    (``is_training=False``).

    Variables are only initialized here (``global_variables_initializer``).
    This function does not restore any checkpoint.

    :param batch_size: fixed number of inputs per batch. It is used for both
        the placeholder shape and the wrapper, so the two cannot disagree.
    :return: a ``TensorflowSlimWrapper`` identified as ``'tf-vgg16'`` with
        ``labels_offset=1``.
    """
    import tensorflow as tf
    from nets import nets_factory
    from preprocessing import vgg_preprocessing
    from model_tools.activations.tensorflow import load_resize_image
    tf.compat.v1.reset_default_graph()

    image_size = 224
    # Each element is passed to `load_resize_image`; presumably an image path.
    placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=[batch_size])

    def _preprocess(image_path):
        image = load_resize_image(image_path, image_size)
        return vgg_preprocessing.preprocess_image(
            image, image_size, image_size, resize_side_min=image_size)

    images = tf.map_fn(_preprocess, placeholder, dtype=tf.float32)

    # 1001 outputs together with labels_offset=1: presumably index 0 is a
    # background class, as in the TF-slim ImageNet checkpoints. TODO confirm.
    model_ctr = nets_factory.get_network_fn('vgg_16',
                                            num_classes=1001,
                                            is_training=False)
    _, endpoints = model_ctr(images)

    session = tf.compat.v1.Session()
    # `initialize_all_variables` is deprecated; this is its drop-in replacement.
    session.run(tf.compat.v1.global_variables_initializer())
    return TensorflowSlimWrapper(identifier='tf-vgg16',
                                 labels_offset=1,
                                 endpoints=endpoints,
                                 inputs=placeholder,
                                 session=session,
                                 batch_size=batch_size)
示例#2
0
def tfslim_custom(batch_size=64):
    """Build a small hand-written TF-slim network and wrap it in a ``TensorflowSlimWrapper``.

    The default graph is reset. A string placeholder with ``batch_size``
    elements is then mapped element-wise through ``load_resize_image``
    (224x224). The network is built under variable scope ``'my_model'``::

        conv1 (64 filters, 11x11, stride 4, VALID)
        -> pool1 (max, 5x5, stride 5)
        -> pool2 (max, 3x3, stride 2)
        -> flatten
        -> logits (fully connected, 1000 units)

    Variables are only initialized (``global_variables_initializer``). No
    checkpoint is restored.

    :param batch_size: fixed number of inputs per batch. It is used for both
        the placeholder shape and the wrapper, so the two cannot disagree.
    :return: a ``TensorflowSlimWrapper`` identified as ``'tf-custom'`` with
        ``labels_offset=0``.
    """
    from model_tools.activations.tensorflow import load_resize_image
    import tensorflow as tf
    # NOTE: `tf.contrib` only exists in TensorFlow 1.x.
    slim = tf.contrib.slim
    tf.compat.v1.reset_default_graph()

    image_size = 224
    # Each element is passed to `load_resize_image`; presumably an image path.
    placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=[batch_size])

    def _load(image_path):
        return load_resize_image(image_path, image_size)

    images = tf.map_fn(_load, placeholder, dtype=tf.float32)

    with tf.compat.v1.variable_scope('my_model', values=[images]) as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        # The arg_scope makes every conv2d, fully_connected and max_pool2d
        # layer record its output in `end_points_collection`.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=[end_points_collection]):
            net = slim.conv2d(images, 64, [11, 11], 4, padding='VALID', scope='conv1')
            net = slim.max_pool2d(net, [5, 5], 5, scope='pool1')
            net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
            net = slim.flatten(net, scope='flatten')
            net = slim.fully_connected(net, 1000, scope='logits')
            endpoints = slim.utils.convert_collection_to_dict(end_points_collection)

    session = tf.compat.v1.Session()
    # `initialize_all_variables` is deprecated; this is its drop-in replacement.
    session.run(tf.compat.v1.global_variables_initializer())
    return TensorflowSlimWrapper(identifier='tf-custom', labels_offset=0,
                                 endpoints=endpoints, inputs=placeholder, session=session,
                                 batch_size=batch_size)
示例#3
0
    def init(identifier,
             preprocessing_type,
             image_size,
             net_name=None,
             labels_offset=1,
             batch_size=64,
             model_ctr_kwargs=None):
        """Instantiate a ``nets_factory`` network and wrap it for activation extraction.

        Steps:
        1. Reset the default graph.
        2. Create a string placeholder with ``batch_size`` elements.
        3. Preprocess it via ``TFSlimModel._init_preprocessing``.
        4. Build the network named ``net_name`` (falling back to
           ``identifier``) for ``labels_offset + 1000`` classes with
           ``is_training=False``.
        5. Rename a ``'Logits'`` endpoint, if present, to ``'logits'``.
        6. Create a session and pass it, with ``identifier``, to
           ``TFSlimModel._restore_imagenet_weights``.

        :param identifier: model identifier; also the default network name.
        :param preprocessing_type: forwarded to ``_init_preprocessing``.
        :param image_size: forwarded to preprocessing and stored on the
            returned wrapper as ``image_size``.
        :param net_name: ``nets_factory`` network name; defaults to
            ``identifier`` when falsy.
        :param labels_offset: added to 1000 to form ``num_classes``; also
            passed to the wrapper.
        :param batch_size: fixed placeholder batch dimension, also passed to
            the wrapper.
        :param model_ctr_kwargs: extra keyword arguments for the network
            function (none when falsy).
        :return: the configured ``TensorflowSlimWrapper``.
        """
        import tensorflow as tf
        from nets import nets_factory

        tf.compat.v1.reset_default_graph()
        image_inputs = tf.compat.v1.placeholder(dtype=tf.string, shape=[batch_size])
        images = TFSlimModel._init_preprocessing(
            image_inputs, preprocessing_type, image_size=image_size)

        network_fn = nets_factory.get_network_fn(
            net_name or identifier,
            num_classes=labels_offset + 1000,
            is_training=False,
        )
        extra_kwargs = model_ctr_kwargs if model_ctr_kwargs else {}
        _, endpoints = network_fn(images, **extra_kwargs)

        # Unify capitalization: expose the classifier output as 'logits'.
        if 'Logits' in endpoints:
            endpoints['logits'] = endpoints.pop('Logits')

        session = tf.compat.v1.Session()
        TFSlimModel._restore_imagenet_weights(identifier, session)
        wrapper = TensorflowSlimWrapper(
            identifier=identifier,
            endpoints=endpoints,
            inputs=image_inputs,
            session=session,
            batch_size=batch_size,
            labels_offset=labels_offset,
        )
        wrapper.image_size = image_size
        return wrapper