コード例 #1
0
  def test_search_space_construction(self, ssd):
    """Builds a model for search space `ssd` and checks the logits shape.

    The spec returned by `get_search_space_spec(ssd)` is passed through
    `test_utils.with_random_masks` before the model is constructed.
    """
    batch_size = 128
    num_classes = 1001
    spec = test_utils.with_random_masks(
        mobile_search_space_v3.get_search_space_spec(ssd))
    model = mobile_model_v3.get_model(
        spec, num_classes=num_classes, force_stateless_batch_norm=True)

    images = tf.random_normal(shape=[batch_size, 224, 224, 3])
    model.build(images.shape)

    # Only the logits are checked here; the endpoints are discarded.
    logits, _ = model.apply(images, training=True)
    self.assertEqual(logits.shape, [batch_size, num_classes])
コード例 #2
0
  def test_output_shapes_with_variable_kernel_sizes(self):
    """Checks logits and endpoint shapes for `mobilenet_v3_like_search()`.

    For a batch of 8 inputs of size 224x224x3, five endpoints are expected,
    with spatial sizes 112, 56, 28, 14 and 7. Their channel dimension is left
    unconstrained (`None`).
    """
    spec = test_utils.with_random_masks(
        mobile_search_space_v3.mobilenet_v3_like_search())
    model = mobile_model_v3.get_model(
        spec, num_classes=50, force_stateless_batch_norm=True)

    images = tf.ones([8, 224, 224, 3])
    model.build(images.shape)
    logits, endpoints = model.apply(images, training=True)

    self.assertEqual(logits.shape, tf.TensorShape([8, 50]))
    self.assertLen(endpoints, 5)
    # The length is asserted above, so zip() visits every endpoint.
    expected_spatial_sizes = (112, 56, 28, 14, 7)
    for endpoint, size in zip(endpoints, expected_spatial_sizes):
      endpoint.shape.assert_is_compatible_with([8, size, size, None])
コード例 #3
0
  def test_output_shapes_with_variable_filter_sizes(self):
    """Checks endpoint shapes after scaling the conv tower by multipliers.

    The `mobilenet_v3_large()` spec is scaled with
    `search_space_utils.scale_conv_tower_spec`. Each endpoint's channel count
    is expected to equal the listed per-stage filter count times the largest
    multiplier.
    """
    filter_multipliers = (0.5, 1.0, 2.0)
    largest = max(filter_multipliers)

    spec = mobile_search_space_v3.mobilenet_v3_large()
    spec = search_space_utils.scale_conv_tower_spec(
        spec, multipliers=filter_multipliers)
    model = mobile_model_v3.get_model(
        test_utils.with_random_masks(spec), num_classes=50)

    images = tf.ones([8, 224, 224, 3])
    model.build(images.shape)
    logits, endpoints = model.apply(images, training=True)

    self.assertEqual(logits.shape, tf.TensorShape([8, 50]))
    spatial_sizes = (112, 56, 28, 14, 7)
    base_filters = (16, 24, 40, 112, 160)
    expected_shapes = [
        [8, size, size, int(filters * largest)]
        for size, filters in zip(spatial_sizes, base_filters)
    ]
    self.assertEqual([x.shape for x in endpoints], expected_shapes)