Exemplo n.º 1
0
def backbone_generator(params):
    """Build the backbone selected by ``params.architecture.backbone``.

    Recognised backbone names are ``'resnet'``, ``'spinenet'`` and
    ``'spinenet_mbconv'``. Each one reads its own sub-config
    (``params.resnet``, ``params.spinenet``, ``params.spinenet_mbconv``)
    together with the shared ``params.batch_norm_activation`` settings.

    For the two SpineNet variants, a non-empty ``block_specs`` field is
    decoded with ``json.loads`` before being handed to the module's
    ``build_block_specs``. When the field is empty, ``None`` is handed
    instead.

    Args:
      params: configuration object exposing ``architecture`` and the
        per-backbone sub-configs accessed below.

    Returns:
      Whatever the selected builder (``resnet.Resnet``,
      ``spinenet.spinenet_builder`` or
      ``spinenet_mbconv.spinenet_mbconv_builder``) returns.

    Raises:
      ValueError: if the backbone name is not one of the recognised values.
    """

    def _decode_specs(raw_specs):
        # Empty/falsy specs mean "no override": build_block_specs gets None.
        return json.loads(raw_specs) if raw_specs else None

    arch = params.architecture
    name = arch.backbone

    if name == 'resnet':
        cfg = params.resnet
        bn_cfg = params.batch_norm_activation
        return resnet.Resnet(
            resnet_depth=cfg.resnet_depth,
            dropblock=dropblock_generator(params.dropblock),
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate)

    if name == 'spinenet':
        cfg = params.spinenet
        specs = _decode_specs(cfg.block_specs)
        bn_cfg = params.batch_norm_activation
        return spinenet.spinenet_builder(
            model_id=cfg.model_id,
            min_level=arch.min_level,
            max_level=arch.max_level,
            block_specs=spinenet.build_block_specs(specs),
            use_native_resize_op=cfg.use_native_resize_op,
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate)

    if name == 'spinenet_mbconv':
        cfg = params.spinenet_mbconv
        specs = _decode_specs(cfg.block_specs)
        bn_cfg = params.batch_norm_activation
        return spinenet_mbconv.spinenet_mbconv_builder(
            model_id=cfg.model_id,
            min_level=arch.min_level,
            max_level=arch.max_level,
            block_specs=spinenet_mbconv.build_block_specs(specs),
            use_native_resize_op=cfg.use_native_resize_op,
            se_ratio=cfg.se_ratio,
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate)

    raise ValueError('Backbone model %s is not supported.' % name)
Exemplo n.º 2
0
def backbone_generator(params):
    """Build the backbone selected by ``params.architecture.backbone``.

    Recognised backbone names:
      * ``'resnet'``: ``resnet.Resnet``, which also receives
        ``params.architecture.space_to_depth_block_size``.
      * ``'spinenet'``: ``spinenet.spinenet_builder``.
      * ``'spinenet_mbconv'``: ``spinenet_mbconv.spinenet_mbconv_builder``.
      * any name that *contains* ``'efficientnet'``:
        ``efficientnet.Efficientnet(name)``.

    The exact-name checks run before the substring check.

    Args:
      params: configuration object exposing ``architecture`` and the
        per-backbone sub-configs accessed below.

    Returns:
      Whatever the selected builder returns.

    Raises:
      ValueError: if the backbone name matches none of the cases above.
    """
    arch = params.architecture
    name = arch.backbone

    if name == 'resnet':
        cfg = params.resnet
        bn_cfg = params.batch_norm_activation
        return resnet.Resnet(
            resnet_depth=cfg.resnet_depth,
            dropblock=dropblock_generator(params.dropblock),
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate,
            space_to_depth_block_size=arch.space_to_depth_block_size)

    if name == 'spinenet':
        cfg = params.spinenet
        bn_cfg = params.batch_norm_activation
        return spinenet.spinenet_builder(
            model_id=cfg.model_id,
            min_level=arch.min_level,
            max_level=arch.max_level,
            use_native_resize_op=cfg.use_native_resize_op,
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate)

    if name == 'spinenet_mbconv':
        cfg = params.spinenet_mbconv
        bn_cfg = params.batch_norm_activation
        return spinenet_mbconv.spinenet_mbconv_builder(
            model_id=cfg.model_id,
            min_level=arch.min_level,
            max_level=arch.max_level,
            use_native_resize_op=cfg.use_native_resize_op,
            se_ratio=cfg.se_ratio,
            activation=bn_cfg.activation,
            batch_norm_activation=batch_norm_activation_generator(bn_cfg),
            init_drop_connect_rate=cfg.init_drop_connect_rate)

    # Substring match: the full variant name is forwarded to Efficientnet.
    if 'efficientnet' in name:
        return efficientnet.Efficientnet(name)

    raise ValueError('Backbone model %s is not supported.' % name)