def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
    """Batch norm `train: false` yields a falsy `is_training` param.

    The text proto configures a batch_norm block with `train: false`, while
    the builder itself is invoked with `is_training=True`. The test checks
    that the first scope entry uses `layers.batch_norm` as its normalizer and
    that the normalizer params carry the configured decay, epsilon, center
    and scale values together with a falsy `is_training`.
    """
    text_proto = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } batch_norm { decay: 0.7 center: false scale: true epsilon: 0.03 train: false } """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    # Only the first entry of the returned scope is inspected.
    scope_entries = list(scope.values())
    op_kwargs = scope_entries[0]
    self.assertEqual(op_kwargs['normalizer_fn'], layers.batch_norm)

    bn_params = op_kwargs['normalizer_params']
    self.assertAlmostEqual(bn_params['decay'], 0.7)
    self.assertAlmostEqual(bn_params['epsilon'], 0.03)
    self.assertFalse(bn_params['center'])
    self.assertTrue(bn_params['scale'])
    # `train: false` in the proto must show up here even though the builder
    # was called with is_training=True.
    self.assertFalse(bn_params['is_training'])
def test_default_arg_scope_has_conv2d_transpose_op(self):
    """The scope built without an explicit `op` covers `conv2d_transpose`.

    The text proto sets only a regularizer and an initializer (no `op`
    field). The test checks that the key derived from
    `layers.conv2d_transpose` via `self._get_scope_key` is present in the
    scope returned by the builder.
    """
    conv_hyperparams_text_proto = """ regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)

    # assertIn reports both the missing key and the scope contents on
    # failure, unlike assertTrue(key in scope) which only says "False is
    # not true".
    self.assertIn(self._get_scope_key(layers.conv2d_transpose), scope)
def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
    """A proto with `op: FC` produces a scope covering `fully_connected`.

    The test checks that the key derived from `layers.fully_connected` via
    `self._get_scope_key` is present in the scope returned by the builder.
    """
    conv_hyperparams_text_proto = """ op: FC regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)

    # assertIn reports both the missing key and the scope contents on
    # failure, unlike assertTrue(key in scope) which only says "False is
    # not true".
    self.assertIn(self._get_scope_key(layers.fully_connected), scope)
def test_use_relu_6_activation(self):
    """`activation: RELU_6` in the proto maps to `tf.nn.relu6`.

    Checks the `activation_fn` argument of the first entry in the scope
    returned by the builder.
    """
    text_proto = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } activation: RELU_6 """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    scope_entries = list(scope.values())
    activation_fn = scope_entries[0]['activation_fn']
    self.assertEqual(activation_fn, tf.nn.relu6)
def test_do_not_use_batch_norm_if_default(self):
    """Without a `batch_norm` block no normalizer is configured.

    The text proto sets only a regularizer and an initializer. The test
    checks that both `normalizer_fn` and `normalizer_params` of the first
    scope entry are None.
    """
    conv_hyperparams_text_proto = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { } } """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
    conv_scope_arguments = list(scope.values())[0]

    # Identity check: assertIsNone cannot be fooled by an object whose
    # __eq__ happens to compare equal to None, and states the intent.
    self.assertIsNone(conv_scope_arguments['normalizer_fn'])
    self.assertIsNone(conv_scope_arguments['normalizer_params'])
def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(
        self):
    """All three entries of the default scope carry identical kwargs.

    The scope is expected to hold exactly three entries (the unpacking
    below fails otherwise). Per the test name these are presumably the
    separable_conv2d, conv2d and conv2d_transpose ops -- the code itself
    only relies on the count, not on which entry is which.
    """
    text_proto = """ regularizer { l1_regularizer { } } initializer { truncated_normal_initializer { } } """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    # Unpacking doubles as a check that there are exactly three entries.
    reference_kwargs, second_kwargs, third_kwargs = scope.values()
    for other_kwargs in (second_kwargs, third_kwargs):
        self.assertDictEqual(reference_kwargs, other_kwargs)
def test_return_l1_regularized_weights(self):
    """The built L1 regularizer returns `weight * sum(|w|)`.

    The proto configures `l1_regularizer { weight: 0.5 }`. The regularizer
    taken from the first scope entry is applied to a small constant tensor
    and its evaluated result is compared with `np.abs(weights).sum() * 0.5`.
    """
    text_proto = """ regularizer { l1_regularizer { weight: 0.5 } } initializer { truncated_normal_initializer { } } """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    op_kwargs = list(scope.values())[0]
    weights_regularizer = op_kwargs['weights_regularizer']
    sample_weights = np.array([1., -1, 4., 2.])

    with self.test_session() as sess:
        # Build the penalty op inside the session context, then evaluate it.
        penalty_op = weights_regularizer(tf.constant(sample_weights))
        penalty = sess.run(penalty_op)
    self.assertAllClose(np.abs(sample_weights).sum() * 0.5, penalty)
def test_variance_in_range_with_truncated_normal_initializer(self):
    """Truncated-normal initializer samples have the expected variance.

    The proto configures `truncated_normal_initializer` with mean 0.0 and
    stddev 0.8. The `weights_initializer` from the first scope entry is
    handed to `self._assert_variance_in_range` for a [100, 40] shape with
    an expected variance of 0.49 and tolerance 1e-1.
    """
    text_proto = """ regularizer { l2_regularizer { } } initializer { truncated_normal_initializer { mean: 0.0 stddev: 0.8 } } """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    op_kwargs = list(scope.values())[0]
    weights_initializer = op_kwargs['weights_initializer']

    # NOTE(review): 0.49 is below stddev**2 = 0.64, presumably because
    # truncation narrows the sampled distribution -- confirm against the
    # truncated-normal initializer's documented behavior.
    self._assert_variance_in_range(
        weights_initializer, shape=[100, 40], variance=0.49, tol=1e-1)
def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
    """Uniform FAN_IN variance-scaling samples have the expected variance.

    The proto configures `variance_scaling_initializer` with factor 2.0,
    mode FAN_IN and `uniform: true`. The `weights_initializer` from the
    first scope entry is handed to `self._assert_variance_in_range` for a
    [100, 40] shape with an expected variance of 2 / 100.
    """
    text_proto = """ regularizer { l2_regularizer { } } initializer { variance_scaling_initializer { factor: 2.0 mode: FAN_IN uniform: true } } """
    hyperparams = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams)
    scope = hyperparams_builder.build(hyperparams, is_training=True)

    op_kwargs = list(scope.values())[0]
    weights_initializer = op_kwargs['weights_initializer']

    # Expected variance is the configured factor (2.0) over 100, the first
    # shape dimension -- presumably the fan-in for this shape; verify.
    self._assert_variance_in_range(
        weights_initializer, shape=[100, 40], variance=2. / 100.)