Exemple #1
0
def build(config, is_training):
    """Build a bidirectional RNN from a ``BidirectionalRnn`` proto config.

    Args:
        config: A ``bidirectional_rnn_pb2.BidirectionalRnn`` message.
        is_training: Forwarded to the FC hyperparams builder when an output
            projection is requested.

    Returns:
        A ``bidirectional_rnn.StaticBidirectionalRnn`` instance when
        ``config.static`` is set, otherwise a
        ``bidirectional_rnn.DynamicBidirectionalRnn`` instance.

    Raises:
        ValueError: If ``config`` is not a ``BidirectionalRnn`` message, or if
            ``config.num_output_units > 0`` and ``config.fc_hyperparams.op``
            is not ``Hyperparams.FC``.
    """
    if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn):
        raise ValueError(
            'config not of type bidirectional_rnn_pb2.BidirectionalRnn')

    brnn_class = (bidirectional_rnn.StaticBidirectionalRnn if config.static
                  else bidirectional_rnn.DynamicBidirectionalRnn)

    # The cell builder is invoked once per direction, both times from the
    # same shared ``fw_bw_rnn_cell`` config.
    cell_config = config.fw_bw_rnn_cell
    forward_cell = rnn_cell_builder.build(cell_config)
    backward_cell = rnn_cell_builder.build(cell_config)
    regularizer = hyperparams_builder._build_regularizer(
        config.rnn_regularizer)

    # FC hyperparams are only built (and validated) when an output
    # projection is requested.
    fc_hyperparams = None
    if config.num_output_units > 0:
        if config.fc_hyperparams.op != hyperparams_pb2.Hyperparams.FC:
            raise ValueError('op type must be FC')
        fc_hyperparams = hyperparams_builder.build(config.fc_hyperparams,
                                                   is_training)

    return brnn_class(
        forward_cell,
        backward_cell,
        rnn_regularizer=regularizer,
        num_output_units=config.num_output_units,
        fc_hyperparams=fc_hyperparams,
        summarize_activations=config.summarize_activations,
    )
Exemple #2
0
def _build_stn_resnet(config, is_training):
    """Build a ``resnet.ResnetForSTN`` from a ``convnet_pb2.StnResnet`` config.

    Args:
        config: A ``convnet_pb2.StnResnet`` message.
        is_training: Forwarded to the conv hyperparams builder and the network.

    Returns:
        A ``resnet.ResnetForSTN`` instance.

    Raises:
        ValueError: If ``config`` is not a ``convnet_pb2.StnResnet`` message.
    """
    if not isinstance(config, convnet_pb2.StnResnet):
        raise ValueError('config is not of type convnet_pb2.StnResnet')
    network_class = resnet.ResnetForSTN
    conv_hyperparams = hyperparams_builder.build(config.conv_hyperparams,
                                                 is_training)
    return network_class(
        conv_hyperparams=conv_hyperparams,
        summarize_activations=config.summarize_activations,
        is_training=is_training,
    )
Exemple #3
0
def _build_stn_convnet(config, is_training):
    """Build an STN convnet from a ``convnet_pb2.StnConvnet`` config.

    Args:
        config: A ``convnet_pb2.StnConvnet`` message.
        is_training: Forwarded to the conv hyperparams builder and the network.

    Returns:
        A ``stn_convnet.StnConvnetTiny`` instance when ``config.tiny`` is set,
        otherwise a ``stn_convnet.StnConvnet`` instance.

    Raises:
        ValueError: If ``config`` is not a ``convnet_pb2.StnConvnet`` message.
    """
    if not isinstance(config, convnet_pb2.StnConvnet):
        raise ValueError('config is not of type convnet_pb2.StnConvnet')
    # Test truthiness directly instead of ``== True`` (PEP 8 / flake8 E712).
    convnet_class = (stn_convnet.StnConvnetTiny if config.tiny
                     else stn_convnet.StnConvnet)
    conv_hyperparams = hyperparams_builder.build(config.conv_hyperparams,
                                                 is_training)
    return convnet_class(conv_hyperparams=conv_hyperparams,
                         summarize_activations=config.summarize_activations,
                         is_training=is_training)
Exemple #4
0
def build(config, is_training):
  """Build a ``spatial_transformer.SpatialTransformer`` from its proto config.

  Args:
    config: A ``spatial_transformer_pb2.SpatialTransformer`` message.
    is_training: Forwarded to the convnet builder and the FC hyperparams
      builder.

  Returns:
    A ``spatial_transformer.SpatialTransformer`` configured from ``config``.

  Raises:
    ValueError: If ``config`` is not a ``SpatialTransformer`` message.
  """
  if not isinstance(config, spatial_transformer_pb2.SpatialTransformer):
    raise ValueError(
        'config not of type spatial_transformer_pb2.SpatialTransformer')

  convnet = convnet_builder.build(config.convnet, is_training)
  fc_hyperparams = hyperparams_builder.build(
      config.fc_hyperparams, is_training)

  # Paired scalar proto fields are packed into tuples for the constructor.
  localization_size = (config.localization_h, config.localization_w)
  output_size = (config.output_h, config.output_w)
  margins = (config.margin_x, config.margin_y)

  return spatial_transformer.SpatialTransformer(
      convnet=convnet,
      fc_hyperparams=fc_hyperparams,
      localization_image_size=localization_size,
      output_image_size=output_size,
      num_control_points=config.num_control_points,
      init_bias_pattern=config.init_bias_pattern,
      margins=margins,
      activation=config.activation,
      summarize_activations=config.summarize_activations)
Exemple #5
0
def _build_resnet(config, is_training):
    """Build a ResNet from a ``convnet_pb2.ResNet`` config.

    Only the ``SINGLE_BRANCH`` network type and the ``RESNET_50`` depth are
    handled; anything else is rejected.

    Args:
        config: A ``convnet_pb2.ResNet`` message.
        is_training: Forwarded to the conv hyperparams builder and the network.

    Returns:
        A ``resnet.Resnet50Layer`` instance.

    Raises:
        ValueError: If ``config`` is not a ``convnet_pb2.ResNet`` message, if
            ``config.net_type`` is not ``SINGLE_BRANCH``, or if
            ``config.net_depth`` is not ``RESNET_50``.
    """
    if not isinstance(config, convnet_pb2.ResNet):
        raise ValueError('config is not of type convnet_pb2.ResNet')
    if config.net_type != convnet_pb2.ResNet.SINGLE_BRANCH:
        raise ValueError('Only SINGLE_BRANCH is supported for ResNet')

    depth = config.net_depth
    if depth != convnet_pb2.ResNet.RESNET_50:
        raise ValueError('Unknown resnet depth: {}'.format(depth))

    network_class = resnet.Resnet50Layer
    hyperparams = hyperparams_builder.build(config.conv_hyperparams,
                                            is_training)
    return network_class(conv_hyperparams=hyperparams,
                         summarize_activations=config.summarize_activations,
                         is_training=is_training)
Exemple #6
0
def _build_crnn_net(config, is_training):
    """Build a CRNN convolutional network from a ``convnet_pb2.CrnnNet`` config.

    Args:
        config: A ``convnet_pb2.CrnnNet`` message.
        is_training: Forwarded to the conv hyperparams builder and the network.

    Returns:
        An instance of ``crnn_net.CrnnNet``, ``crnn_net.CrnnNetTwoBranches`` or
        ``crnn_net.CrnnNetThreeBranches``, selected by ``config.net_type``.
        When ``config.tiny`` is set, ``crnn_net.CrnnNetTiny`` is used instead.

    Raises:
        ValueError: If ``config`` is not a ``convnet_pb2.CrnnNet`` message, or
            if ``config.net_type`` is not one of ``SINGLE_BRANCH``,
            ``TWO_BRANCHES`` or ``THREE_BRANCHES``.
    """
    if not isinstance(config, convnet_pb2.CrnnNet):
        raise ValueError('config is not of type convnet_pb2.CrnnNet')

    net_type = config.net_type
    if net_type == convnet_pb2.CrnnNet.SINGLE_BRANCH:
        crnn_net_class = crnn_net.CrnnNet
    elif net_type == convnet_pb2.CrnnNet.TWO_BRANCHES:
        crnn_net_class = crnn_net.CrnnNetTwoBranches
    elif net_type == convnet_pb2.CrnnNet.THREE_BRANCHES:
        crnn_net_class = crnn_net.CrnnNetThreeBranches
    else:
        raise ValueError('Unknown net_type: {}'.format(net_type))

    # ``net_type`` is still validated above, but ``tiny`` takes precedence:
    # the tiny variant replaces whichever branch class was selected.
    # Truthiness test instead of ``== True`` (PEP 8 / flake8 E712).
    if config.tiny:
        crnn_net_class = crnn_net.CrnnNetTiny

    hyperparams_object = hyperparams_builder.build(config.conv_hyperparams,
                                                   is_training)

    return crnn_net_class(conv_hyperparams=hyperparams_object,
                          summarize_activations=config.summarize_activations,
                          is_training=is_training)