def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
    """`batch_norm { train: false }` must yield is_training=False params.

    The scope is deliberately built with is_training=True, so this checks
    that the proto's `train: false` setting takes precedence.
    """
    proto_text = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } batch_norm { decay: 0.7 center: false scale: true epsilon: 0.03 train: false } """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    # Inspect the kwargs registered for the first op in the arg scope.
    op_kwargs = tuple(arg_scope.values())[0]
    self.assertEqual(op_kwargs['normalizer_fn'], layers.batch_norm)

    bn_params = op_kwargs['normalizer_params']
    self.assertAlmostEqual(bn_params['decay'], 0.7)
    self.assertAlmostEqual(bn_params['epsilon'], 0.03)
    self.assertFalse(bn_params['center'])
    self.assertTrue(bn_params['scale'])
    # build() received is_training=True, yet the proto says train: false.
    self.assertFalse(bn_params['is_training'])
def build(config, is_training):
    """Build a bidirectional RNN object from a BidirectionalRnn config.

    Args:
      config: a bidirectional_rnn_pb2.BidirectionalRnn message.
      is_training: forwarded to hyperparams_builder.build when building the
        fully connected hyperparams.

    Returns:
      A bidirectional_rnn.StaticBidirectionalRnn when `config.static` is set,
      otherwise a bidirectional_rnn.DynamicBidirectionalRnn.

    Raises:
      ValueError: if `config` is not a BidirectionalRnn message, or if
        `num_output_units > 0` and `fc_hyperparams.op` is not FC.
    """
    if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn):
        raise ValueError(
            'config not of type bidirectional_rnn_pb2.BidirectionalRnn')

    rnn_cls = (bidirectional_rnn.StaticBidirectionalRnn if config.static
               else bidirectional_rnn.DynamicBidirectionalRnn)

    # One cell object per direction, each built from the shared cell config.
    forward_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    backward_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    regularizer = hyperparams_builder._build_regularizer(config.rnn_regularizer)

    # fc_hyperparams are only validated and built when an output projection
    # is requested; otherwise None is passed through.
    fc_params = None
    if config.num_output_units > 0:
        fc_config = config.fc_hyperparams
        if fc_config.op != hyperparams_pb2.Hyperparams.FC:
            raise ValueError('op type must be FC')
        fc_params = hyperparams_builder.build(fc_config, is_training)

    return rnn_cls(
        forward_cell,
        backward_cell,
        rnn_regularizer=regularizer,
        num_output_units=config.num_output_units,
        fc_hyperparams=fc_params,
        summarize_activations=config.summarize_activations)
def _build_stn_resnet(config, is_training):
    """Build a resnet.ResnetForSTN from a StnResnet config.

    Args:
      config: a convnet_pb2.StnResnet message.
      is_training: passed both to the conv hyperparams builder and to the
        network constructor.

    Returns:
      A resnet.ResnetForSTN instance.

    Raises:
      ValueError: if `config` is not a convnet_pb2.StnResnet message.
    """
    if not isinstance(config, convnet_pb2.StnResnet):
        raise ValueError('config is not of type convnet_pb2.StnResnet')

    conv_params = hyperparams_builder.build(config.conv_hyperparams,
                                            is_training)
    return resnet.ResnetForSTN(
        conv_hyperparams=conv_params,
        summarize_activations=config.summarize_activations,
        is_training=is_training)
def _build_stn_convnet(config, is_training):
    """Build the STN convnet described by a StnConvnet config.

    Args:
      config: a convnet_pb2.StnConvnet message.
      is_training: passed both to the conv hyperparams builder and to the
        network constructor.

    Returns:
      A stn_convnet.StnConvnetTiny instance when `config.tiny` is set,
      otherwise a stn_convnet.StnConvnet instance.

    Raises:
      ValueError: if `config` is not a convnet_pb2.StnConvnet message.
    """
    if not isinstance(config, convnet_pb2.StnConvnet):
        raise ValueError('config is not of type convnet_pb2.StnConvnet')

    # Test the proto bool directly rather than comparing with `== True`.
    convnet_class = (stn_convnet.StnConvnetTiny if config.tiny
                     else stn_convnet.StnConvnet)
    return convnet_class(
        conv_hyperparams=hyperparams_builder.build(config.conv_hyperparams,
                                                   is_training),
        summarize_activations=config.summarize_activations,
        is_training=is_training)
def test_default_arg_scope_has_conv2d_transpose_op(self):
    """Without an explicit `op`, the arg scope must cover conv2d_transpose."""
    conv_hyperparams_text_proto = """ regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
    # assertIn reports the missing key and the scope contents on failure,
    # whereas assertTrue(key in scope) only reports "False is not true".
    self.assertIn(self._get_scope_key(layers.conv2d_transpose), scope)
def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
    """With `op: FC`, the arg scope must cover layers.fully_connected."""
    conv_hyperparams_text_proto = """ op: FC regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
    # assertIn gives a diagnostic failure message (key plus container),
    # unlike assertTrue(key in scope).
    self.assertIn(self._get_scope_key(layers.fully_connected), scope)
def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
    """All three ops in the default arg scope must receive identical kwargs."""
    proto_text = """ regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    # The three-way unpack also fails if the scope does not hold exactly
    # three entries.
    first, second, third = arg_scope.values()
    for other in (second, third):
        self.assertDictEqual(first, other)
def test_use_relu_6_activation(self):
    """`activation: RELU_6` must be mapped to tf.nn.relu6."""
    proto_text = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } activation: RELU_6 """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    op_kwargs = tuple(arg_scope.values())[0]
    self.assertEqual(op_kwargs['activation_fn'], tf.nn.relu6)
def test_do_not_use_batch_norm_if_default(self):
    """Without a `batch_norm` section, no normalizer must be configured."""
    conv_hyperparams_text_proto = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
    conv_scope_arguments = list(scope.values())[0]
    # assertIsNone checks identity with None (not __eq__) and reports the
    # offending value on failure.
    self.assertIsNone(conv_scope_arguments['normalizer_fn'])
    self.assertIsNone(conv_scope_arguments['normalizer_params'])
def build(config, is_training):
    """Build a SpatialTransformer from a SpatialTransformer config.

    Args:
      config: a spatial_transformer_pb2.SpatialTransformer message.
      is_training: forwarded to the convnet builder and to the fc
        hyperparams builder.

    Returns:
      A spatial_transformer.SpatialTransformer instance.

    Raises:
      ValueError: if `config` is not a SpatialTransformer message.
    """
    if not isinstance(config, spatial_transformer_pb2.SpatialTransformer):
        raise ValueError(
            'config not of type spatial_transformer_pb2.SpatialTransformer')

    localization_convnet = convnet_builder.build(config.convnet, is_training)
    fc_params = hyperparams_builder.build(config.fc_hyperparams, is_training)

    # Sizes are packed as (h, w) pairs; margins are packed as (x, y).
    localization_size = (config.localization_h, config.localization_w)
    output_size = (config.output_h, config.output_w)
    margins = (config.margin_x, config.margin_y)

    return spatial_transformer.SpatialTransformer(
        convnet=localization_convnet,
        fc_hyperparams=fc_params,
        localization_image_size=localization_size,
        output_image_size=output_size,
        num_control_points=config.num_control_points,
        init_bias_pattern=config.init_bias_pattern,
        margins=margins,
        activation=config.activation,
        summarize_activations=config.summarize_activations)
def _build_resnet(config, is_training):
    """Build a single-branch ResNet from a ResNet config.

    Only `net_type == SINGLE_BRANCH` with `net_depth == RESNET_50` is
    currently supported.

    Args:
      config: a convnet_pb2.ResNet message.
      is_training: passed both to the conv hyperparams builder and to the
        network constructor.

    Returns:
      A resnet.Resnet50Layer instance.

    Raises:
      ValueError: if `config` has the wrong type, or if its net_type or
        net_depth is unsupported.
    """
    if not isinstance(config, convnet_pb2.ResNet):
        raise ValueError('config is not of type convnet_pb2.ResNet')
    if config.net_type != convnet_pb2.ResNet.SINGLE_BRANCH:
        raise ValueError('Only SINGLE_BRANCH is supported for ResNet')

    depth = config.net_depth
    if depth != convnet_pb2.ResNet.RESNET_50:
        raise ValueError('Unknown resnet depth: {}'.format(depth))
    resnet_class = resnet.Resnet50Layer

    conv_params = hyperparams_builder.build(config.conv_hyperparams,
                                            is_training)
    return resnet_class(
        conv_hyperparams=conv_params,
        summarize_activations=config.summarize_activations,
        is_training=is_training,
    )
def test_variance_in_range_with_truncated_normal_initializer(self):
    """A truncated normal initializer (stddev 0.8) must give variance ~0.49."""
    proto_text = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { mean: 0.0 stddev: 0.8 } } """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    weights_init = tuple(arg_scope.values())[0]['weights_initializer']
    # NOTE(review): 0.49 is below stddev**2 = 0.64, presumably because
    # truncation shrinks the variance; confirm against the initializer's
    # truncation bound.
    self._assert_variance_in_range(weights_init, shape=[100, 40],
                                   variance=0.49, tol=1e-1)
def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
    """Uniform variance scaling (factor 2, FAN_IN) must give variance 2/100."""
    proto_text = """ regularizer { l2_regularizer { } } initializer { variance_scaling_initializer { factor: 2.0 mode: FAN_IN uniform: true } } """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    weights_init = tuple(arg_scope.values())[0]['weights_initializer']
    # Expected variance is factor / fan_in, which assumes fan_in is the
    # leading dimension (100) of the [100, 40] shape.
    self._assert_variance_in_range(weights_init, shape=[100, 40],
                                   variance=2. / 100.)
def test_return_l1_regularized_weights(self):
    """An l1_regularizer with weight 0.5 must evaluate to 0.5 * sum(|w|)."""
    proto_text = """ regularizer { l1_regularizer { weight: 0.5 } } initializer { truncated_normal_initializer { } } """
    hyperparams_config = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, hyperparams_config)
    arg_scope = hyperparams_builder.build(hyperparams_config, is_training=True)

    l1_regularizer = tuple(arg_scope.values())[0]['weights_regularizer']
    sample = np.array([1., -1, 4., 2.])
    expected = np.abs(sample).sum() * 0.5
    with self.test_session() as session:
        actual = session.run(l1_regularizer(tf.constant(sample)))
    self.assertAllClose(expected, actual)
def _build_crnn_net(config, is_training):
    """Build a CRNN convnet from a CrnnNet config.

    `net_type` selects the single-, two- or three-branch variant. When
    `config.tiny` is set, crnn_net.CrnnNetTiny replaces that selection.
    `net_type` is still validated first, so an unknown value raises even
    when `tiny` is set.

    Args:
      config: a convnet_pb2.CrnnNet message.
      is_training: passed both to the conv hyperparams builder and to the
        network constructor.

    Returns:
      An instance of the selected crnn_net class.

    Raises:
      ValueError: if `config` has the wrong type or an unknown net_type.
    """
    if not isinstance(config, convnet_pb2.CrnnNet):
        raise ValueError('config is not of type convnet_pb2.CrnnNet')

    classes_by_net_type = {
        convnet_pb2.CrnnNet.SINGLE_BRANCH: crnn_net.CrnnNet,
        convnet_pb2.CrnnNet.TWO_BRANCHES: crnn_net.CrnnNetTwoBranches,
        convnet_pb2.CrnnNet.THREE_BRANCHES: crnn_net.CrnnNetThreeBranches,
    }
    crnn_net_class = classes_by_net_type.get(config.net_type)
    if crnn_net_class is None:
        raise ValueError('Unknown net_type: {}'.format(config.net_type))

    # Test the proto bool directly rather than comparing with `== True`.
    if config.tiny:
        crnn_net_class = crnn_net.CrnnNetTiny

    hyperparams_object = hyperparams_builder.build(config.conv_hyperparams,
                                                   is_training)
    return crnn_net_class(conv_hyperparams=hyperparams_object,
                          summarize_activations=config.summarize_activations,
                          is_training=is_training)