Example #1
0
def resnet_v2_50(inputs,
                 config,
                 is_training=True,
                 scope='resnet_v2_50'):
    """Modified ResNet-50 model whose four blocks are described by `config`.

    For each block name ``blockN`` (N = 1..4) the attributes
    ``config.blockN_depth``, ``config.blockN_units`` and ``config.blockN_stride``
    are read and passed to ``resnet_v2.resnet_v2_block`` as ``base_depth``,
    ``num_units`` and ``stride``. The network is built with
    ``global_pool=False`` and the root block enabled; see ``resnet_v2()`` for
    the return value.
    """
    blocks = []
    for index in range(1, 5):
        block_name = 'block%d' % index
        # Attribute names follow the pattern '<block name>_<field>'.
        blocks.append(
            resnet_v2.resnet_v2_block(
                block_name,
                base_depth=getattr(config, block_name + '_depth'),
                num_units=getattr(config, block_name + '_units'),
                stride=getattr(config, block_name + '_stride')))

    return resnet_v2.resnet_v2(
        inputs,
        blocks,
        is_training=is_training,
        global_pool=False,
        include_root_block=True,
        scope=scope)
Example #2
0
def resnet_12(inputs, num_classes, scope='resnet_12', *,
              is_training=True, reuse=None):
    """ResNet-12 built from five identical pre-activation bottleneck blocks.

    Blocks ``block1`` .. ``block5`` each use ``base_depth=64``, ``num_units=2``
    and ``stride=1``, so no block downsamples. The root block is included,
    global pooling is enabled, and ``output_stride`` is left at ``None``. See
    ``resnet_v2()`` for the full argument and return description.

    Args:
        inputs: input tensor, forwarded unchanged to ``resnet_v2()``.
        num_classes: number of output classes, forwarded to ``resnet_v2()``.
        scope: variable scope name for the network.
        is_training: forwarded to ``resnet_v2()``. Defaults to True; pass
            False to build an inference graph.
        reuse: forwarded to ``resnet_v2()`` to control variable sharing.
            Defaults to None.

    Returns:
        Whatever ``resnet_v2.resnet_v2`` returns for these arguments.
    """
    # Five blocks with identical hyper-parameters; only the name differs.
    blocks = [
        resnet_v2.resnet_v2_block('block%d' % index,
                                  base_depth=64,
                                  num_units=2,
                                  stride=1)
        for index in range(1, 6)
    ]
    return resnet_v2.resnet_v2(inputs,
                               blocks,
                               num_classes,
                               is_training=is_training,
                               global_pool=True,
                               output_stride=None,
                               include_root_block=True,
                               reuse=reuse,
                               scope=scope)
def resnet_small(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 include_root_block=True,
                 reuse=None,
                 scope='resnet_small'):
    """Small four-block ResNet-v2.

    Every block has two units and stride 2; the base depth doubles from block
    to block (32, 64, 128, 256). All remaining arguments are forwarded to
    ``resnet_v2()``; see it for their meaning and the return value.
    """
    base_depths = (32, 64, 128, 256)
    blocks = [
        resnet_v2.resnet_v2_block('block%d' % (position + 1),
                                  base_depth=depth,
                                  num_units=2,
                                  stride=2)
        for position, depth in enumerate(base_depths)
    ]
    return resnet_v2.resnet_v2(
        inputs,
        blocks,
        num_classes,
        is_training=is_training,
        global_pool=global_pool,
        output_stride=output_stride,
        include_root_block=include_root_block,
        reuse=reuse,
        scope=scope,
    )
Example #4
0
def _parse_resnet_v2_block(shape: list, config: dict):
    """
    Function to parse resnet_v2_block, which creates a stack of preactivation bottleneck units.
    Documentation can be found here: https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py

    Parameters
    ----------
        shape: list holding the shape of expected input.
        config: dict describing the block, holding 'base_depth', 'num_units' and 'stride' keys.
            Optional: 'scope' (defaults to a random UUID string).

    Returns
    -------
        layers: None
        variables: All trainable variables created under the block's variable scope.
        function: callable x -> output that applies every unit of the block (one entry of
            `block.args` per unit) inside the block's variable scope. The scope is opened
            with AUTO_REUSE, so repeated calls share the same variables.
        output_shape: shape of the block output for an input of `shape`.
    """
    scope_name = config.get('scope', str(uuid.uuid4()))
    base_depth = config.get('base_depth')
    num_units = config.get('num_units')
    stride = config.get('stride')

    # resnet_v2_block takes the block *name* (a string), not a VariableScope object.
    block = resnet_v2.resnet_v2_block(scope_name, base_depth, num_units, stride)

    def function(x):
        # The scope must actually be entered for variables to land under `scope_name`.
        # AUTO_REUSE makes later calls reuse the variables created by the first call.
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
            # Apply all units, not just the first one: block.args holds one argument dict
            # per unit (in slim, the block stride sits on the last unit). A stable per-unit
            # scope keeps variable names identical across calls so reuse can find them.
            for index, unit_args in enumerate(block.args):
                with tf.variable_scope('unit_%d' % (index + 1)):
                    x = block.unit_fn(x,
                                      unit_args.get('depth'),
                                      unit_args.get('depth_bottleneck'),
                                      unit_args.get('stride'))
        return x

    layers = None
    # Building the block once on a placeholder yields the static output shape and
    # creates the block's variables.
    output_shape = function(tf.placeholder(tf.float32,
                                           shape=shape)).get_shape()
    # Re-enter the scope only to resolve its full name in the current context, then
    # collect the trainable variables created beneath it.
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as var_scope:
        variables = var_scope.trainable_variables()

    return layers, variables, function, output_shape
def resnet_v2_light(inputs,
                    num_classes=None,
                    is_training=True,
                    global_pool=True,
                    output_stride=None,
                    spatial_squeeze=True,
                    reuse=None,
                    scope='resnet_v2_light'):
    """ResNet-light model of AIlab. See resnet_v2() for arg and return description.

    Four blocks with (base_depth, num_units, stride) of (16, 3, 2), (32, 4, 2),
    (64, 8, 2) and (128, 3, 1). The root block is always included.

    All keyword arguments, including ``spatial_squeeze``, are forwarded to
    ``resnet_v2.resnet_v2``. ``spatial_squeeze`` was previously accepted but
    silently ignored.
    """
    # (name, base_depth, num_units, stride) for each block, in order.
    block_specs = [
        ('block1', 16, 3, 2),
        ('block2', 32, 4, 2),
        ('block3', 64, 8, 2),
        ('block4', 128, 3, 1),
    ]
    blocks = [
        resnet_v2.resnet_v2_block(name,
                                  base_depth=base_depth,
                                  num_units=num_units,
                                  stride=stride)
        for name, base_depth, num_units, stride in block_specs
    ]

    return resnet_v2.resnet_v2(inputs,
                               blocks,
                               num_classes,
                               is_training=is_training,
                               global_pool=global_pool,
                               output_stride=output_stride,
                               include_root_block=True,
                               spatial_squeeze=spatial_squeeze,
                               reuse=reuse,
                               scope=scope)
Example #6
0
def resnet_v2_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 reuse=None,
                 scope='resnet_v2_50'):
    """Single-block ResNet-v2 variant built with ``resnet_v2_f``.

    Despite its name, this builds exactly one block: 'block4' with
    base_depth=512, one unit and stride 1, plus the root block.

    NOTE(review): ``output_stride`` is accepted but never used. A literal ``4``
    is passed as the sixth positional argument of ``resnet_v2_f``, presumably
    its output_stride. Confirm against resnet_v2_f's signature.
    """
    only_block = resnet_v2.resnet_v2_block('block4',
                                           base_depth=512,
                                           num_units=1,
                                           stride=1)
    # Positional arguments are kept exactly as resnet_v2_f receives them.
    positional_args = (inputs, [only_block], num_classes, is_training,
                       global_pool, 4)
    return resnet_v2_f(*positional_args,
                       include_root_block=True,
                       reuse=reuse,
                       scope=scope)