コード例 #1
0
def hourglass_32(channel_means, channel_stds, bgr_ordering):
  """The Hourglass-32 backbone for CenterNet.

  Builds `hourglass_network.hourglass_32` with 128 channels and wraps it in a
  `CenterNetHourglassFeatureExtractor`.

  Args:
    channel_means: Forwarded unchanged as the extractor's `channel_means`
      (presumably per-channel input normalization means -- confirm against
      the extractor).
    channel_stds: Forwarded unchanged as the extractor's `channel_stds`
      (presumably per-channel input normalization stds -- confirm).
    bgr_ordering: Forwarded unchanged as the extractor's `bgr_ordering`.

  Returns:
    A `CenterNetHourglassFeatureExtractor` wrapping the Hourglass-32 network.
  """

  network = hourglass_network.hourglass_32(num_channels=128)
  return CenterNetHourglassFeatureExtractor(
      network, channel_means=channel_means, channel_stds=channel_stds,
      bgr_ordering=bgr_ordering)
コード例 #2
0
ファイル: deepmac_meta_arch.py プロジェクト: ykate1998/models
def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
    """Get DeepMAC network model given a string type."""

    if name.startswith('hourglass'):
        if name == 'hourglass10':
            return hourglass_network.hourglass_10(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass20':
            return hourglass_network.hourglass_20(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass32':
            return hourglass_network.hourglass_32(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass52':
            return hourglass_network.hourglass_52(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass100':
            return hourglass_network.hourglass_100(num_init_channels,
                                                   initial_downsample=False)
        elif name == 'hourglass20_uniform_size':
            return hourglass_network.hourglass_20_uniform_size(
                num_init_channels)

        elif name == 'hourglass20_no_shortcut':
            return hourglass_network.hourglass_20_no_shortcut(
                num_init_channels)

    elif name == 'fully_connected':
        if not mask_size:
            raise ValueError('Mask size must be set.')
        return FullyConnectedMaskHead(num_init_channels, mask_size)

    elif name == 'embedding_projection':
        return tf.keras.layers.Lambda(lambda x: x)

    elif name.startswith('resnet'):
        return ResNetMaskNetwork(name, num_init_channels)

    raise ValueError('Unknown network type {}'.format(name))
コード例 #3
0
  def test_hourglass_32(self):
    """Checks the depth and output shape of a Hourglass-32 network.

    The network is built with 2 initial channels and no initial
    downsampling. A batch of two 32x32x3 zero images is expected to yield a
    first output of shape (2, 32, 32, 4).
    """
    # NOTE(review): the hourglass factories take `initial_downsample`, not
    # `downsample`; the original keyword would be rejected by that
    # signature -- confirm against the hourglass module in use.
    net = hourglass.hourglass_32(2, initial_downsample=False)
    self.assertEqual(hourglass.hourglass_depth(net), 32)

    outputs = net(tf.zeros((2, 32, 32, 3)))
    self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
コード例 #4
0
 def test_hourglass_32(self):
     """Asserts that a Hourglass-32 built with 2 channels has depth 32."""
     expected_depth = 32
     model = hourglass.hourglass_32(2)
     measured_depth = hourglass.hourglass_depth(model)
     self.assertEqual(measured_depth, expected_depth)