示例#1
0
 def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
   """Batch-norm params come from the proto; `train: false` overrides is_training."""
   text_proto = """
     regularizer {
       l2_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
     batch_norm {
       decay: 0.7
       center: false
       scale: true
       epsilon: 0.03
       train: false
     }
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   scope_args = list(scope.values())[0]
   self.assertEqual(scope_args['normalizer_fn'], layers.batch_norm)
   bn_params = scope_args['normalizer_params']
   # Even though the builder was called with is_training=True, the proto's
   # `train: false` should force a non-training batch norm.
   self.assertFalse(bn_params['is_training'])
   self.assertAlmostEqual(bn_params['decay'], 0.7)
   self.assertAlmostEqual(bn_params['epsilon'], 0.03)
   self.assertFalse(bn_params['center'])
   self.assertTrue(bn_params['scale'])
示例#2
0
def build(config, is_training):
    """Build a bidirectional RNN object from a BidirectionalRnn proto.

    Args:
        config: a bidirectional_rnn_pb2.BidirectionalRnn proto.
        is_training: bool, whether the model is being built for training.

    Returns:
        A Static- or DynamicBidirectionalRnn instance.

    Raises:
        ValueError: if `config` has the wrong type, or if fully-connected
            hyperparams are given with an op other than FC.
    """
    if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn):
        raise ValueError(
            'config not of type bidirectional_rnn_pb2.BidirectionalRnn')

    brnn_class = (bidirectional_rnn.StaticBidirectionalRnn
                  if config.static
                  else bidirectional_rnn.DynamicBidirectionalRnn)

    # Forward and backward cells are built from the same cell config.
    fw_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    bw_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    regularizer = hyperparams_builder._build_regularizer(
        config.rnn_regularizer)

    # An optional fully-connected output layer is added only when
    # num_output_units is positive.
    fc_hyperparams = None
    if config.num_output_units > 0:
        if config.fc_hyperparams.op != hyperparams_pb2.Hyperparams.FC:
            raise ValueError('op type must be FC')
        fc_hyperparams = hyperparams_builder.build(
            config.fc_hyperparams, is_training)

    return brnn_class(fw_cell,
                      bw_cell,
                      rnn_regularizer=regularizer,
                      num_output_units=config.num_output_units,
                      fc_hyperparams=fc_hyperparams,
                      summarize_activations=config.summarize_activations)
示例#3
0
def _build_stn_resnet(config, is_training):
    """Build a ResnetForSTN localization network from an StnResnet proto."""
    if not isinstance(config, convnet_pb2.StnResnet):
        raise ValueError('config is not of type convnet_pb2.StnResnet')
    conv_hyperparams = hyperparams_builder.build(
        config.conv_hyperparams, is_training)
    return resnet.ResnetForSTN(
        conv_hyperparams=conv_hyperparams,
        summarize_activations=config.summarize_activations,
        is_training=is_training)
示例#4
0
def _build_stn_convnet(config, is_training):
    """Build an StnConvnet localization network from an StnConvnet proto.

    Args:
        config: a convnet_pb2.StnConvnet proto.
        is_training: bool, whether the model is being built for training.

    Returns:
        An StnConvnet (or StnConvnetTiny when `config.tiny` is set) instance.

    Raises:
        ValueError: if `config` is not an StnConvnet proto.
    """
    if not isinstance(config, convnet_pb2.StnConvnet):
        raise ValueError('config is not of type convnet_pb2.StnConvnet')
    # PEP 8: test truthiness directly instead of comparing `== True`.
    convnet_class = (stn_convnet.StnConvnetTiny if config.tiny
                     else stn_convnet.StnConvnet)
    conv_hyperparams = hyperparams_builder.build(
        config.conv_hyperparams, is_training)
    return convnet_class(
        conv_hyperparams=conv_hyperparams,
        summarize_activations=config.summarize_activations,
        is_training=is_training)
示例#5
0
 def test_default_arg_scope_has_conv2d_transpose_op(self):
   """The default (CONV) hyperparams scope covers conv2d_transpose."""
   conv_hyperparams_text_proto = """
     regularizer {
       l1_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
   """
   conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
   scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
   # assertIn reports the scope's actual keys on failure, unlike
   # assertTrue(x in scope) which only prints "False is not true".
   self.assertIn(self._get_scope_key(layers.conv2d_transpose), scope)
示例#6
0
 def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
   """With `op: FC`, the built scope covers layers.fully_connected."""
   conv_hyperparams_text_proto = """
     op: FC
     regularizer {
       l1_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
   """
   conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
   scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
   # assertIn reports the scope's actual keys on failure, unlike
   # assertTrue(x in scope) which only prints "False is not true".
   self.assertIn(self._get_scope_key(layers.fully_connected), scope)
示例#7
0
 def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
   """All three conv ops in the scope receive identical keyword arguments."""
   text_proto = """
     regularizer {
       l1_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   # The scope is expected to hold exactly three op entries.
   first_kwargs, second_kwargs, third_kwargs = scope.values()
   self.assertDictEqual(first_kwargs, second_kwargs)
   self.assertDictEqual(first_kwargs, third_kwargs)
示例#8
0
 def test_use_relu_6_activation(self):
   """`activation: RELU_6` in the proto maps to tf.nn.relu6 in the scope."""
   text_proto = """
     regularizer {
       l2_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
     activation: RELU_6
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   scope_args = list(scope.values())[0]
   self.assertEqual(scope_args['activation_fn'], tf.nn.relu6)
示例#9
0
 def test_do_not_use_batch_norm_if_default(self):
   """Without a batch_norm message, the scope sets no normalizer at all."""
   conv_hyperparams_text_proto = """
     regularizer {
       l2_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
   """
   conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
   scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
   conv_scope_arguments = list(scope.values())[0]
   # assertIsNone is the unittest-recommended form of assertEqual(x, None)
   # and produces a clearer failure message.
   self.assertIsNone(conv_scope_arguments['normalizer_fn'])
   self.assertIsNone(conv_scope_arguments['normalizer_params'])
示例#10
0
def build(config, is_training):
    """Build a SpatialTransformer from a SpatialTransformer proto.

    Args:
        config: a spatial_transformer_pb2.SpatialTransformer proto.
        is_training: bool, whether the model is being built for training.

    Returns:
        A spatial_transformer.SpatialTransformer instance.

    Raises:
        ValueError: if `config` has the wrong type.
    """
    if not isinstance(config, spatial_transformer_pb2.SpatialTransformer):
        raise ValueError(
            'config not of type spatial_transformer_pb2.SpatialTransformer')

    localization_convnet = convnet_builder.build(config.convnet, is_training)
    fc_hyperparams = hyperparams_builder.build(
        config.fc_hyperparams, is_training)
    return spatial_transformer.SpatialTransformer(
        convnet=localization_convnet,
        fc_hyperparams=fc_hyperparams,
        localization_image_size=(config.localization_h, config.localization_w),
        output_image_size=(config.output_h, config.output_w),
        num_control_points=config.num_control_points,
        init_bias_pattern=config.init_bias_pattern,
        margins=(config.margin_x, config.margin_y),
        activation=config.activation,
        summarize_activations=config.summarize_activations)
示例#11
0
def _build_resnet(config, is_training):
    """Build a single-branch ResNet from a ResNet proto.

    Raises:
        ValueError: if `config` has the wrong type, requests a net_type other
            than SINGLE_BRANCH, or an unsupported depth.
    """
    if not isinstance(config, convnet_pb2.ResNet):
        raise ValueError('config is not of type convnet_pb2.ResNet')
    if config.net_type != convnet_pb2.ResNet.SINGLE_BRANCH:
        raise ValueError('Only SINGLE_BRANCH is supported for ResNet')

    # Guard clause: RESNET_50 is the only depth currently supported.
    depth = config.net_depth
    if depth != convnet_pb2.ResNet.RESNET_50:
        raise ValueError('Unknown resnet depth: {}'.format(depth))
    resnet_class = resnet.Resnet50Layer

    hyperparams = hyperparams_builder.build(
        config.conv_hyperparams, is_training)
    return resnet_class(
        conv_hyperparams=hyperparams,
        summarize_activations=config.summarize_activations,
        is_training=is_training,
    )
示例#12
0
 def test_variance_in_range_with_truncated_normal_initializer(self):
   """Truncated normal with stddev 0.8 yields sample variance near 0.49."""
   text_proto = """
     regularizer {
       l2_regularizer {
       }
     }
     initializer {
       truncated_normal_initializer {
         mean: 0.0
         stddev: 0.8
       }
     }
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   scope_args = list(scope.values())[0]
   weights_initializer = scope_args['weights_initializer']
   # Truncation pulls the realized variance below stddev**2 (0.64),
   # hence the 0.49 target with a loose tolerance.
   self._assert_variance_in_range(weights_initializer, shape=[100, 40],
                                  variance=0.49, tol=1e-1)
示例#13
0
 def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
   """Uniform variance-scaling (factor 2, FAN_IN) gives variance factor/fan_in."""
   text_proto = """
     regularizer {
       l2_regularizer {
       }
     }
     initializer {
       variance_scaling_initializer {
         factor: 2.0
         mode: FAN_IN
         uniform: true
       }
     }
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   scope_args = list(scope.values())[0]
   weights_initializer = scope_args['weights_initializer']
   # For shape [100, 40] under FAN_IN, fan_in is 100, so the expected
   # variance is factor / fan_in = 2 / 100.
   self._assert_variance_in_range(weights_initializer, shape=[100, 40],
                                  variance=2. / 100.)
示例#14
0
 def test_return_l1_regularized_weights(self):
   """An L1 regularizer with weight 0.5 computes 0.5 * sum(|w|)."""
   text_proto = """
     regularizer {
       l1_regularizer {
         weight: 0.5
       }
     }
     initializer {
       truncated_normal_initializer {
       }
     }
   """
   proto = hyperparams_pb2.Hyperparams()
   text_format.Merge(text_proto, proto)
   scope = hyperparams_builder.build(proto, is_training=True)
   scope_args = list(scope.values())[0]
   regularizer = scope_args['weights_regularizer']
   weights = np.array([1., -1, 4., 2.])
   with self.test_session() as sess:
     regularization = sess.run(regularizer(tf.constant(weights)))
   expected = np.abs(weights).sum() * 0.5
   self.assertAllClose(expected, regularization)
示例#15
0
def _build_crnn_net(config, is_training):
    """Build a CRNN convnet from a CrnnNet proto.

    Args:
        config: a convnet_pb2.CrnnNet proto.
        is_training: bool, whether the model is being built for training.

    Returns:
        An instance of the CrnnNet variant selected by `config.net_type`
        (SINGLE_BRANCH / TWO_BRANCHES / THREE_BRANCHES), replaced by
        CrnnNetTiny when `config.tiny` is set.

    Raises:
        ValueError: if `config` has the wrong type or an unknown net_type.
    """
    if not isinstance(config, convnet_pb2.CrnnNet):
        raise ValueError('config is not of type convnet_pb2.CrnnNet')

    # Dispatch table replaces the if/elif chain over net_type.
    net_type_to_class = {
        convnet_pb2.CrnnNet.SINGLE_BRANCH: crnn_net.CrnnNet,
        convnet_pb2.CrnnNet.TWO_BRANCHES: crnn_net.CrnnNetTwoBranches,
        convnet_pb2.CrnnNet.THREE_BRANCHES: crnn_net.CrnnNetThreeBranches,
    }
    if config.net_type not in net_type_to_class:
        raise ValueError('Unknown net_type: {}'.format(config.net_type))
    crnn_net_class = net_type_to_class[config.net_type]

    # PEP 8: test truthiness directly instead of comparing `== True`.
    # Note: tiny overrides net_type, but an unknown net_type still raises,
    # matching the original validation order.
    if config.tiny:
        crnn_net_class = crnn_net.CrnnNetTiny

    hyperparams_object = hyperparams_builder.build(
        config.conv_hyperparams, is_training)

    return crnn_net_class(conv_hyperparams=hyperparams_object,
                          summarize_activations=config.summarize_activations,
                          is_training=is_training)