Beispiel #1
0
def build_frontend(inputs,
                   frontend,
                   is_training=True,
                   pretrained_dir="models"):
    """Build a ResNet-v2 frontend and a function that restores its weights.

    Args:
        inputs: Passed unchanged to the slim network function (presumably an
            image batch tensor; this function does not check it).
        frontend: Backbone name, either 'ResNet50' or 'ResNet101'.
        is_training: Forwarded to the network function.
        pretrained_dir: Directory holding '<frontend_scope>.ckpt'.

    Returns:
        A tuple (logits, end_points, frontend_scope, init_fn): the two values
        returned by the network function, the variable scope name the network
        was built under, and the callable produced by
        slim.assign_from_checkpoint_fn for that scope's model variables.

    Raises:
        ValueError: If `frontend` is not one of the supported names.
    """
    supported = ('ResNet50', 'ResNet101')
    if frontend not in supported:
        # The message is derived from `supported` so it cannot drift from the
        # set of backbones that are actually implemented below.
        raise ValueError(
            "Unsupported frontend model '%s'. This function only supports %s"
            % (frontend, ', '.join(supported)))

    if frontend == 'ResNet50':
        frontend_scope = 'resnet_v2_50'
        network_fn = resnet_v2.resnet_v2_50
    else:  # 'ResNet101' (membership was checked above)
        frontend_scope = 'resnet_v2_101'
        network_fn = resnet_v2.resnet_v2_101

    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        logits, end_points = network_fn(
            inputs, is_training=is_training, scope=frontend_scope)

    # ignore_missing_vars=True: variables absent from the checkpoint are
    # skipped instead of making the restore fail.
    init_fn = slim.assign_from_checkpoint_fn(
        model_path=os.path.join(pretrained_dir, frontend_scope + '.ckpt'),
        var_list=slim.get_model_variables(frontend_scope),
        ignore_missing_vars=True)

    return logits, end_points, frontend_scope, init_fn
Beispiel #2
0
def build_frontend(inputs, frontend, is_training=True):
    """Build the requested classification backbone on float-cast inputs.

    Args:
        inputs: Tensor-like value; it is cast with tf.to_float before being
            handed to the network function.
        frontend: One of 'ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2'
            or 'InceptionV4'.
        is_training: Forwarded to the network function.

    Returns:
        A tuple (logits, end_points, frontend_scope): the two values returned
        by the network function and the variable scope name it was built
        under.

    Raises:
        ValueError: If `frontend` is not one of the supported names.
    """
    supported = ('ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2',
                 'InceptionV4')
    # Validate before touching the graph so an unsupported name does not
    # leave a stray cast op behind.
    if frontend not in supported:
        raise ValueError(
            "Unsupported frontend model '%s'. This function only supports %s"
            % (frontend, ', '.join(supported)))

    inputs = tf.to_float(inputs)

    if frontend == 'ResNet50':
        frontend_scope = 'resnet_v2_50'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_50(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'ResNet101':
        frontend_scope = 'resnet_v2_101'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_101(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'ResNet152':
        frontend_scope = 'resnet_v2_152'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_152(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'MobileNetV2':
        frontend_scope = 'mobilenet_v2'
        with slim.arg_scope(mobilenet_v2.training_scope()):
            logits, end_points = mobilenet_v2.mobilenet(
                inputs, is_training=is_training, scope=frontend_scope,
                base_only=True)
    else:  # 'InceptionV4' (membership was checked above)
        frontend_scope = 'inception_v4'
        with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
            logits, end_points = inception_v4.inception_v4(
                inputs, is_training=is_training, scope=frontend_scope)

    return logits, end_points, frontend_scope
def build_frontend(inputs, frontend, is_training=True, pretrained_dir="checkpoints"):
    """Build the requested backbone and a function that restores its weights.

    Args:
        inputs: Passed unchanged to the slim network function (presumably an
            image batch tensor; this function does not check it).
        frontend: One of 'ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2'
            or 'InceptionV4'.
        is_training: Forwarded to the network function.
        pretrained_dir: Directory holding the pretrained checkpoint files.

    Returns:
        A tuple (logits, end_points, frontend_scope, init_fn): the two values
        returned by the network function, the variable scope name the network
        was built under, and the callable produced by
        slim.assign_from_checkpoint_fn for that scope's model variables.

    Raises:
        ValueError: If `frontend` is not one of the supported names.
    """
    supported = ('ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2',
                 'InceptionV4')
    if frontend not in supported:
        raise ValueError(
            "Unsupported frontend model '%s'. This function only supports %s"
            % (frontend, ', '.join(supported)))

    if frontend == 'ResNet50':
        frontend_scope = 'resnet_v2_50'
        # Resolved against pretrained_dir like every other backbone; this
        # used to be a machine-specific absolute path.
        checkpoint_name = 'resnet_v2_50.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_50(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'ResNet101':
        frontend_scope = 'resnet_v2_101'
        checkpoint_name = 'resnet_v2_101.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_101(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'ResNet152':
        frontend_scope = 'resnet_v2_152'
        checkpoint_name = 'resnet_v2_152.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_152(
                inputs, is_training=is_training, scope=frontend_scope)
    elif frontend == 'MobileNetV2':
        frontend_scope = 'MobilenetV2'
        checkpoint_name = 'mobilenet_v2_1.0_224.ckpt'
        with slim.arg_scope(mobilenet_v2.training_scope()):
            logits, end_points = mobilenet_v2.mobilenet(
                inputs, is_training=is_training, scope=frontend_scope,
                base_only=True)
    else:  # 'InceptionV4' (membership was checked above)
        frontend_scope = 'inception_v4'
        checkpoint_name = 'inception_v4.ckpt'
        with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
            logits, end_points = inception_v4.inception_v4(
                inputs, is_training=is_training, scope=frontend_scope)

    # ignore_missing_vars=True: variables absent from the checkpoint are
    # skipped instead of making the restore fail.
    init_fn = slim.assign_from_checkpoint_fn(
        model_path=os.path.join(pretrained_dir, checkpoint_name),
        var_list=slim.get_model_variables(frontend_scope),
        ignore_missing_vars=True)

    return logits, end_points, frontend_scope, init_fn
Beispiel #4
0
def build_frontend(inputs, frontend_config, is_training=True, reuse=False):
    """Build the configured backbone and a function that restores its weights.

    Args:
        inputs: Passed unchanged to the slim network function (presumably an
            image batch tensor; this function does not check it).
        frontend_config: Mapping with the keys 'frontend' (backbone name) and
            'pretrained_dir' (directory holding the checkpoint files).
        is_training: Forwarded to the network function.
        reuse: Forwarded to the network function.

    Returns:
        A tuple (logits, end_points, frontend_scope, init_fn): the two values
        returned by the network function, a scope label for the backbone, and
        the callable produced by slim.assign_from_checkpoint_fn. init_fn is
        None for 'Xception39', which has no checkpoint configured here.

    Raises:
        KeyError: If `frontend_config` lacks 'frontend' or 'pretrained_dir'.
        ValueError: If the configured frontend is not supported.
    """
    frontend = frontend_config['frontend']
    pretrained_dir = frontend_config['pretrained_dir']

    supported = ('ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2',
                 'InceptionV4', 'DenseNet121', 'DenseNet161', 'DenseNet169',
                 'Xception39')
    if frontend not in supported:
        raise ValueError(
            "Unsupported frontend model '%s'. This function only supports %s"
            % (frontend, ', '.join(supported)))

    # Files whose absence triggers a checkpoint download.
    # NOTE(review): these paths are hard-coded to 'pretrain/' while the
    # restore below reads from pretrained_dir -- confirm where
    # download_checkpoints writes before unifying the two.
    download_markers = {
        'ResNet50': "pretrain/resnet_v2_50.ckpt",
        'ResNet101': "pretrain/resnet_v2_101.ckpt",
        'ResNet152': "pretrain/resnet_v2_152.ckpt",
        'MobileNetV2': "pretrain/mobilenet_v2.ckpt.data-00000-of-00001",
        'InceptionV4': "pretrain/inception_v4.ckpt",
    }
    marker = download_markers.get(frontend)
    if marker is not None and not os.path.isfile(marker):
        download_checkpoints(frontend)

    if frontend == 'ResNet50':
        frontend_scope = 'resnet_v2_50'
        checkpoint_name = 'resnet_v2_50.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_50(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'ResNet101':
        frontend_scope = 'resnet_v2_101'
        checkpoint_name = 'resnet_v2_101.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_101(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'ResNet152':
        frontend_scope = 'resnet_v2_152'
        checkpoint_name = 'resnet_v2_152.ckpt'
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_152(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'MobileNetV2':
        frontend_scope = 'mobilenet_v2'
        checkpoint_name = 'mobilenet_v2.ckpt'
        with slim.arg_scope(mobilenet_v2.training_scope()):
            logits, end_points = mobilenet_v2.mobilenet(
                inputs, is_training=is_training, scope=frontend_scope,
                base_only=True, reuse=reuse)
    elif frontend == 'InceptionV4':
        frontend_scope = 'inception_v4'
        checkpoint_name = 'inception_v4.ckpt'
        with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
            logits, end_points = inception_v4.inception_v4(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'DenseNet121':
        frontend_scope = 'densenet121'
        checkpoint_name = 'tf-densenet121/tf-densenet121.ckpt'
        with slim.arg_scope(densenet.densenet_arg_scope()):
            logits, end_points = densenet.densenet121(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'DenseNet161':
        frontend_scope = 'densenet161'
        checkpoint_name = 'tf-densenet161.ckpt'
        with slim.arg_scope(densenet.densenet_arg_scope()):
            # Each DenseNet variant must use its own constructor so the graph
            # matches the checkpoint restored below (this branch previously
            # built densenet121 under the densenet161 scope).
            logits, end_points = densenet.densenet161(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    elif frontend == 'DenseNet169':
        frontend_scope = 'densenet169'
        checkpoint_name = 'tf-densenet169.ckpt'
        with slim.arg_scope(densenet.densenet_arg_scope()):
            # Same fix as DenseNet161: build the 169 architecture, not 121.
            logits, end_points = densenet.densenet169(
                inputs, is_training=is_training, scope=frontend_scope, reuse=reuse)
    else:  # 'Xception39' (membership was checked above)
        with slim.arg_scope(xception.xception_arg_scope()):
            logits, end_points = xception.xception39(
                inputs, is_training=is_training, scope='xception39', reuse=reuse)
        # NOTE(review): the returned label is 'Xception39' although the
        # variables live under 'xception39'; kept as-is because callers may
        # depend on the returned value -- confirm which is intended.
        frontend_scope = 'Xception39'
        checkpoint_name = None

    if checkpoint_name is None:
        init_fn = None
    else:
        # ignore_missing_vars=True: variables absent from the checkpoint are
        # skipped instead of making the restore fail.
        init_fn = slim.assign_from_checkpoint_fn(
            model_path=os.path.join(pretrained_dir, checkpoint_name),
            var_list=slim.get_model_variables(frontend_scope),
            ignore_missing_vars=True)

    return logits, end_points, frontend_scope, init_fn