コード例 #1
0
 def test_resnet_50layer(self):
     """Builds a SINGLE_BRANCH / RESNET_50 convnet and checks its outputs.

     Parses a text-format Convnet proto, builds it with is_training=True,
     verifies the builder returns a resnet.Resnet50Layer, then runs
     extract_features on a random batch of shape [2, 32, 128, 3] (uniform
     values in [0, 255)) and checks that exactly one feature map comes back.
     """
     feature_extractor_text_proto = """
 resnet {
   net_type: SINGLE_BRANCH
   net_depth: RESNET_50
   conv_hyperparams {
     op: CONV
     regularizer { l2_regularizer { weight: 1e-4 } }
     initializer { variance_scaling_initializer { } }
     batch_norm { }
   }
   summarize_activations: false
 }
 """
     convnet_proto = convnet_pb2.Convnet()
     text_format.Merge(feature_extractor_text_proto, convnet_proto)
     convnet_object = convnet_builder.build(convnet_proto, True)
     # assertIsInstance / assertEqual report the offending type / length on
     # failure, whereas assertTrue(...) only says "False is not true".
     self.assertIsInstance(convnet_object, resnet.Resnet50Layer)
     # NOTE(review): presumably NHWC (batch, height, width, channels) — confirm.
     test_image_shape = [2, 32, 128, 3]
     test_input_image = tf.random_uniform(test_image_shape,
                                          minval=0,
                                          maxval=255.0,
                                          dtype=tf.float32,
                                          seed=1)
     feature_maps = convnet_object.extract_features(test_input_image)
     # A SINGLE_BRANCH net must yield exactly one feature map.
     self.assertEqual(len(feature_maps), 1)
     print('Outputs of test_resnet_50layer: {}'.format(feature_maps))
コード例 #2
0
    def test_crnn_net_three_branches(self):
        """Builds a THREE_BRANCHES CRNN convnet and checks its outputs.

        Parses a text-format Convnet proto, builds it with is_training=True,
        verifies the builder returns a crnn_net.CrnnNet, then runs
        extract_features on a random batch of shape [2, 32, 128, 3] (uniform
        values in [0, 255)) and checks that exactly three feature maps are
        returned.
        """
        feature_extractor_text_proto = """
    crnn_net {
      net_type: THREE_BRANCHES
      conv_hyperparams {
        op: CONV
        regularizer { l2_regularizer { weight: 1e-4 } }
        initializer { variance_scaling_initializer { } }
        batch_norm { }
      }
      summarize_activations: false
    }
    """
        convnet_proto = convnet_pb2.Convnet()
        text_format.Merge(feature_extractor_text_proto, convnet_proto)
        convnet_object = convnet_builder.build(convnet_proto, True)
        # assertIsInstance / assertEqual report the offending type / length
        # on failure, whereas assertTrue(...) only says "False is not true".
        self.assertIsInstance(convnet_object, crnn_net.CrnnNet)

        # NOTE(review): presumably NHWC (batch, height, width, channels).
        test_image_shape = [2, 32, 128, 3]
        test_input_image = tf.random_uniform(test_image_shape,
                                             minval=0,
                                             maxval=255.0,
                                             dtype=tf.float32,
                                             seed=1)
        feature_maps = convnet_object.extract_features(test_input_image)
        # THREE_BRANCHES must yield one feature map per branch.
        self.assertEqual(len(feature_maps), 3)
        print(
            'Outputs of test_crnn_net_three_branches: {}'.format(feature_maps))
コード例 #3
0
 def test_build_stn_convnet_tiny(self):
     """Builds a tiny STN convnet and checks its outputs.

     Parses a text-format Convnet proto with `tiny: true`, builds it with
     is_training=True, verifies the builder returns a
     stn_convnet.StnConvnetTiny, then runs extract_features on a random batch
     of shape [2, 64, 128, 3] (uniform values in [0, 255)) and checks that
     exactly one feature map is returned.
     """
     text_proto = """
 stn_convnet {
   conv_hyperparams {
     op: CONV
     regularizer { l2_regularizer { weight: 1e-4 } }
     initializer { variance_scaling_initializer { } }
     batch_norm { decay: 0.99 }
   }
   tiny: true
 }
 """
     convnet_proto = convnet_pb2.Convnet()
     text_format.Merge(text_proto, convnet_proto)
     convnet_object = convnet_builder.build(convnet_proto, True)
     # assertIsInstance / assertEqual report the offending type / length on
     # failure, whereas assertTrue(...) only says "False is not true".
     self.assertIsInstance(convnet_object, stn_convnet.StnConvnetTiny)
     # NOTE(review): presumably NHWC (batch, height, width, channels) — confirm.
     test_image_shape = [2, 64, 128, 3]
     test_input_image = tf.random_uniform(test_image_shape,
                                          minval=0,
                                          maxval=255.0,
                                          dtype=tf.float32,
                                          seed=1)
     feature_maps = convnet_object.extract_features(test_input_image)
     self.assertEqual(len(feature_maps), 1)
     print(
         'Outputs of test_build_stn_convnet_tiny: {}'.format(feature_maps))
コード例 #4
0
def build(config, is_training):
    """Builds a FeatureExtractor from its proto configuration.

    Args:
        config: a feature_extractor_pb2.FeatureExtractor message.
        is_training: passed to the convnet builder, pre-bound into every
            bidirectional-RNN builder partial, and passed to the
            FeatureExtractor constructor.

    Returns:
        A feature_extractor.FeatureExtractor instance.

    Raises:
        ValueError: if `config` is not a
            feature_extractor_pb2.FeatureExtractor.
    """
    if not isinstance(config, feature_extractor_pb2.FeatureExtractor):
        raise ValueError(
            'config not of type feature_extractor_pb2.FeatureExtractor')

    backbone = convnet_builder.build(config.convnet, is_training)

    # The RNNs are not constructed here. Each entry is
    # bidirectional_rnn_builder.build with its config and is_training already
    # bound; any remaining arguments are supplied by whoever calls it
    # (presumably FeatureExtractor — confirm).
    rnn_builders = [
        functools.partial(bidirectional_rnn_builder.build,
                          rnn_cfg,
                          is_training)
        for rnn_cfg in config.bidirectional_rnn
    ]

    return feature_extractor.FeatureExtractor(
        convnet=backbone,
        brnn_fn_list=rnn_builders,
        summarize_activations=config.summarize_activations,
        is_training=is_training)
コード例 #5
0
def build(config, is_training):
    """Builds a SpatialTransformer from its proto configuration.

    Args:
        config: a spatial_transformer_pb2.SpatialTransformer message.
        is_training: passed to both the convnet builder and the
            fc-hyperparams builder.

    Returns:
        A spatial_transformer.SpatialTransformer instance.

    Raises:
        ValueError: if `config` is not a
            spatial_transformer_pb2.SpatialTransformer.
    """
    if not isinstance(config, spatial_transformer_pb2.SpatialTransformer):
        raise ValueError(
            'config not of type spatial_transformer_pb2.SpatialTransformer')

    # Build order (convnet first, then fc hyperparams) matches the original.
    convnet = convnet_builder.build(config.convnet, is_training)
    fc_hyperparams = hyperparams_builder.build(config.fc_hyperparams,
                                               is_training)

    # The proto stores these as separate scalar fields; the constructor takes
    # them as pairs: (h, w) for the image sizes and (x, y) for the margins.
    localization_size = (config.localization_h, config.localization_w)
    output_size = (config.output_h, config.output_w)
    margin_pair = (config.margin_x, config.margin_y)

    return spatial_transformer.SpatialTransformer(
        convnet=convnet,
        fc_hyperparams=fc_hyperparams,
        localization_image_size=localization_size,
        output_image_size=output_size,
        num_control_points=config.num_control_points,
        init_bias_pattern=config.init_bias_pattern,
        margins=margin_pair,
        activation=config.activation,
        summarize_activations=config.summarize_activations)