def test_hourglass_feature_extractor(self):
  """A 2-hourglass network maps a (2, 64, 64, 3) zero batch to (2, 16, 16, 6).

  Builds a small HourglassNetwork and feeds it a float32 batch of two
  64x64, 3-channel zero images. The first two entries of the result must each
  have shape (2, 16, 16, 6): spatial size reduced 4x, with 6 channels.
  """
  images = np.zeros((2, 64, 64, 3), dtype=np.float32)
  network = hourglass.HourglassNetwork(
      num_stages=4,
      blocks_per_stage=[2, 3, 4, 5, 6],
      channel_dims=[4, 6, 8, 10, 12, 14],
      num_hourglasses=2)
  results = network(images)

  expected_shape = (2, 16, 16, 6)
  # NOTE(review): presumably one output per hourglass (num_hourglasses=2);
  # only the first two entries are checked here.
  for index in (0, 1):
    self.assertEqual(results[index].shape, expected_shape)
def test_center_net_hourglass_feature_extractor(self):
  """CenterNet hourglass extractor maps a (2, 64, 64, 3) batch to (2, 16, 16, 6).

  Wraps a small 2-hourglass backbone in CenterNetHourglassFeatureExtractor.
  It then runs the extractor on a float32 zero batch of two 64x64, 3-channel
  images through self.execute. The first two returned entries must each have
  shape (2, 16, 16, 6).
  """
  backbone = hourglass_network.HourglassNetwork(
      num_stages=4,
      blocks_per_stage=[2, 3, 4, 5, 6],
      input_channel_dims=4,
      channel_dims_per_stage=[6, 8, 10, 12, 14],
      num_hourglasses=2)
  extractor = hourglass.CenterNetHourglassFeatureExtractor(backbone)

  def graph_fn():
    # The zero input is created inside graph_fn, which is handed to
    # self.execute (presumably so it runs in the harness's execution mode).
    return extractor(tf.zeros((2, 64, 64, 3), dtype=np.float32))

  results = self.execute(graph_fn, [])

  expected_shape = (2, 16, 16, 6)
  # Only the first two entries are checked.
  for index in (0, 1):
    self.assertEqual(results[index].shape, expected_shape)