示例#1
0
def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
    """Get DeepMAC network model given a string type."""

    if name.startswith('hourglass'):
        if name == 'hourglass10':
            return hourglass_network.hourglass_10(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass20':
            return hourglass_network.hourglass_20(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass32':
            return hourglass_network.hourglass_32(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass52':
            return hourglass_network.hourglass_52(num_init_channels,
                                                  initial_downsample=False)
        elif name == 'hourglass100':
            return hourglass_network.hourglass_100(num_init_channels,
                                                   initial_downsample=False)
        elif name == 'hourglass20_uniform_size':
            return hourglass_network.hourglass_20_uniform_size(
                num_init_channels)

        elif name == 'hourglass20_no_shortcut':
            return hourglass_network.hourglass_20_no_shortcut(
                num_init_channels)

    elif name == 'fully_connected':
        if not mask_size:
            raise ValueError('Mask size must be set.')
        return FullyConnectedMaskHead(num_init_channels, mask_size)

    elif name == 'embedding_projection':
        return tf.keras.layers.Lambda(lambda x: x)

    elif name.startswith('resnet'):
        return ResNetMaskNetwork(name, num_init_channels)

    raise ValueError('Unknown network type {}'.format(name))
    def test_hourglass_100(self):
        """Check depth and output shape of hourglass_100 without downsampling.

        Builds the network with 2 initial channels and initial_downsample=False.
        It then asserts that hourglass_depth reports 100 and that the first
        output for a (2, 32, 32, 3) zero input has shape (2, 32, 32, 4).
        """
        model = hourglass.hourglass_100(2, initial_downsample=False)
        reported_depth = hourglass.hourglass_depth(model)
        self.assertEqual(reported_depth, 100)

        zero_batch = tf.zeros((2, 32, 32, 3))
        first_output = model(zero_batch)[0]
        self.assertEqual(first_output.shape, (2, 32, 32, 4))