def test_resnet_50layer(self):
    """Checks that a SINGLE_BRANCH / RESNET_50 config builds a Resnet50Layer.

    Parses a `resnet` text proto into a convnet_pb2.Convnet message, builds
    the convnet from it, and verifies that:
      * the built object is a resnet.Resnet50Layer, and
      * extract_features returns exactly one feature map for a random
        float32 input of shape [2, 32, 128, 3].
    """
    # Adjacent literals are concatenated by the parser; every fragment
    # carries its own leading space, so the text proto stays on one line.
    feature_extractor_text_proto = (
        " resnet {"
        " net_type: SINGLE_BRANCH"
        " net_depth: RESNET_50"
        " conv_hyperparams {"
        " op: CONV"
        " regularizer { l2_regularizer { weight: 1e-4 } }"
        " initializer { variance_scaling_initializer { } }"
        " batch_norm { }"
        " }"
        " summarize_activations: false"
        " } "
    )
    convnet_proto = convnet_pb2.Convnet()
    text_format.Merge(feature_extractor_text_proto, convnet_proto)
    # NOTE(review): the second positional argument is presumably
    # is_training -- confirm against convnet_builder.build.
    convnet_object = convnet_builder.build(convnet_proto, True)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(convnet_object, resnet.Resnet50Layer)

    test_image_shape = [2, 32, 128, 3]
    test_input_image = tf.random_uniform(
        test_image_shape, minval=0, maxval=255.0, dtype=tf.float32, seed=1)
    feature_maps = convnet_object.extract_features(test_input_image)
    # assertEqual shows the actual number of feature maps when it differs.
    self.assertEqual(len(feature_maps), 1)
    # The label now names this test; it previously named a different one.
    print('Outputs of test_resnet_50layer: {}'.format(feature_maps))
def test_crnn_net_three_branches(self):
    """Checks that a THREE_BRANCHES crnn_net config builds a CrnnNet.

    Parses a `crnn_net` text proto into a convnet_pb2.Convnet message,
    builds the convnet from it, and verifies that:
      * the built object is a crnn_net.CrnnNet, and
      * extract_features returns exactly three feature maps for a random
        float32 input of shape [2, 32, 128, 3].
    """
    # Adjacent literals are concatenated by the parser; every fragment
    # carries its own leading space, so the text proto stays on one line.
    feature_extractor_text_proto = (
        " crnn_net {"
        " net_type: THREE_BRANCHES"
        " conv_hyperparams {"
        " op: CONV"
        " regularizer { l2_regularizer { weight: 1e-4 } }"
        " initializer { variance_scaling_initializer { } }"
        " batch_norm { }"
        " }"
        " summarize_activations: false"
        " } "
    )
    convnet_proto = convnet_pb2.Convnet()
    text_format.Merge(feature_extractor_text_proto, convnet_proto)
    # NOTE(review): the second positional argument is presumably
    # is_training -- confirm against convnet_builder.build.
    convnet_object = convnet_builder.build(convnet_proto, True)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(convnet_object, crnn_net.CrnnNet)

    test_image_shape = [2, 32, 128, 3]
    test_input_image = tf.random_uniform(
        test_image_shape, minval=0, maxval=255.0, dtype=tf.float32, seed=1)
    feature_maps = convnet_object.extract_features(test_input_image)
    # assertEqual shows the actual number of feature maps when it differs.
    self.assertEqual(len(feature_maps), 3)
    print(
        'Outputs of test_crnn_net_three_branches: {}'.format(feature_maps))
def test_build_stn_convnet_tiny(self):
    """Checks that an stn_convnet config with tiny: true builds StnConvnetTiny.

    Parses an `stn_convnet` text proto into a convnet_pb2.Convnet message,
    builds the convnet from it, and verifies that:
      * the built object is a stn_convnet.StnConvnetTiny, and
      * extract_features returns exactly one feature map for a random
        float32 input of shape [2, 64, 128, 3].
    """
    # Adjacent literals are concatenated by the parser; every fragment
    # carries its own leading space, so the text proto stays on one line.
    text_proto = (
        " stn_convnet {"
        " conv_hyperparams {"
        " op: CONV"
        " regularizer { l2_regularizer { weight: 1e-4 } }"
        " initializer { variance_scaling_initializer { } }"
        " batch_norm { decay: 0.99 }"
        " }"
        " tiny: true"
        " } "
    )
    convnet_proto = convnet_pb2.Convnet()
    text_format.Merge(text_proto, convnet_proto)
    # NOTE(review): the second positional argument is presumably
    # is_training -- confirm against convnet_builder.build.
    convnet_object = convnet_builder.build(convnet_proto, True)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(convnet_object, stn_convnet.StnConvnetTiny)

    test_image_shape = [2, 64, 128, 3]
    test_input_image = tf.random_uniform(
        test_image_shape, minval=0, maxval=255.0, dtype=tf.float32, seed=1)
    feature_maps = convnet_object.extract_features(test_input_image)
    # assertEqual shows the actual number of feature maps when it differs.
    self.assertEqual(len(feature_maps), 1)
    print(
        'Outputs of test_build_stn_convnet_tiny: {}'.format(feature_maps))
def build(config, is_training):
    """Creates a FeatureExtractor from its protobuf configuration.

    Args:
      config: a feature_extractor_pb2.FeatureExtractor message.
      is_training: forwarded unchanged to the convnet builder, to every
        bidirectional RNN builder partial, and to the FeatureExtractor.

    Returns:
      A feature_extractor.FeatureExtractor wired with the built convnet and
      one deferred RNN builder per entry in config.bidirectional_rnn.

    Raises:
      ValueError: if config is not a feature_extractor_pb2.FeatureExtractor.
    """
    if not isinstance(config, feature_extractor_pb2.FeatureExtractor):
        raise ValueError(
            'config not of type feature_extractor_pb2.FeatureExtractor')

    backbone = convnet_builder.build(config.convnet, is_training)

    # The RNNs are not built here: each entry is a zero-argument callable
    # that invokes bidirectional_rnn_builder.build when it is called.
    deferred_rnn_builders = []
    for rnn_config in config.bidirectional_rnn:
        deferred_rnn_builders.append(
            functools.partial(
                bidirectional_rnn_builder.build, rnn_config, is_training))

    return feature_extractor.FeatureExtractor(
        convnet=backbone,
        brnn_fn_list=deferred_rnn_builders,
        summarize_activations=config.summarize_activations,
        is_training=is_training)
def build(config, is_training):
    """Creates a SpatialTransformer from its protobuf configuration.

    Args:
      config: a spatial_transformer_pb2.SpatialTransformer message.
      is_training: forwarded unchanged to the convnet builder and to the
        fully-connected hyperparams builder.

    Returns:
      A spatial_transformer.SpatialTransformer constructed from the built
      convnet, the built fc hyperparams, and the scalar fields of config.

    Raises:
      ValueError: if config is not a
        spatial_transformer_pb2.SpatialTransformer.
    """
    if not isinstance(config, spatial_transformer_pb2.SpatialTransformer):
        raise ValueError(
            'config not of type spatial_transformer_pb2.SpatialTransformer')

    # Collect the constructor arguments first; paired config fields are
    # packed into the 2-tuples the constructor takes.
    transformer_kwargs = dict(
        convnet=convnet_builder.build(config.convnet, is_training),
        fc_hyperparams=hyperparams_builder.build(
            config.fc_hyperparams, is_training),
        localization_image_size=(
            config.localization_h, config.localization_w),
        output_image_size=(config.output_h, config.output_w),
        num_control_points=config.num_control_points,
        init_bias_pattern=config.init_bias_pattern,
        margins=(config.margin_x, config.margin_y),
        activation=config.activation,
        summarize_activations=config.summarize_activations,
    )
    return spatial_transformer.SpatialTransformer(**transformer_kwargs)