def test_hourglass_feature_extractor(self):
    """Check output shapes of a small HourglassNetwork with two hourglasses.

    Builds a 4-stage network with ``num_hourglasses=2``, runs it eagerly on
    an all-zeros batch of shape (2, 64, 64, 3), and asserts that the first
    two outputs each have shape (2, 16, 16, 6), i.e. the spatial size is
    reduced 4x (64 -> 16).
    """
    # NOTE(review): this test constructs the network via
    # `hourglass.HourglassNetwork` with a single `channel_dims` list, whereas
    # the CenterNet test below uses `hourglass_network.HourglassNetwork` with
    # `input_channel_dims` + `channel_dims_per_stage`. Confirm that the module
    # bound to `hourglass` in this file really exposes `HourglassNetwork` with
    # a `channel_dims` argument.
    model = hourglass.HourglassNetwork(
        num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6],
        channel_dims=[4, 6, 8, 10, 12, 14], num_hourglasses=2)
    outputs = model(np.zeros((2, 64, 64, 3), dtype=np.float32))
    # Presumably one output per stacked hourglass (num_hourglasses=2) --
    # TODO confirm against the model implementation.
    self.assertEqual(outputs[0].shape, (2, 16, 16, 6))
    self.assertEqual(outputs[1].shape, (2, 16, 16, 6))


# Kept at top level (not nested inside another test's body) so that test
# discovery can find and run it; a nested def is never executed.
def test_center_net_hourglass_feature_extractor(self):
    """Check output shapes of CenterNetHourglassFeatureExtractor.

    Wraps a 4-stage HourglassNetwork with ``num_hourglasses=2`` in
    ``CenterNetHourglassFeatureExtractor``, runs it on an all-zeros batch of
    shape (2, 64, 64, 3) through ``self.execute``, and asserts that the first
    two outputs each have shape (2, 16, 16, 6).
    """
    net = hourglass_network.HourglassNetwork(
        num_stages=4,
        blocks_per_stage=[2, 3, 4, 5, 6],
        input_channel_dims=4,
        channel_dims_per_stage=[6, 8, 10, 12, 14],
        num_hourglasses=2)

    model = hourglass.CenterNetHourglassFeatureExtractor(net)

    def graph_fn():
        # Zero input batch; only output shapes are checked, not values.
        return model(tf.zeros((2, 64, 64, 3), dtype=np.float32))

    # NOTE(review): `self.execute` is assumed to be the project TestCase
    # helper that runs `graph_fn` (graph or eager mode) with the given
    # inputs (none here) and returns its results -- confirm against the base
    # class.
    outputs = self.execute(graph_fn, [])
    self.assertEqual(outputs[0].shape, (2, 16, 16, 6))
    self.assertEqual(outputs[1].shape, (2, 16, 16, 6))