示例#1
0
def construct_network():
    """Create a ``ResnetNoStage`` network and the config it was built from.

    The config is a fresh ``ResnetNoStageConfig`` with 14 keypoints,
    4 image channels, 2 depth values per keypoint and ``num_layers`` of 34.

    Returns:
        A ``(network, net_config)`` tuple.
    """
    # Original author's note on depth_per_keypoint (not verifiable from this
    # file -- TODO confirm):
    #   depthmap_pred set 2 -> (:,3,:,:)  set 3 -> (:,6,:,:)
    settings = {
        "num_keypoints": 14,
        "image_channels": 4,
        "depth_per_keypoint": 2,
        "num_layers": 34,
    }
    config = ResnetNoStageConfig()
    # Apply the settings in the same order as they are declared above.
    for attr_name, attr_value in settings.items():
        setattr(config, attr_name, attr_value)
    return ResnetNoStage(config), config
def construct_network():
    """Create a ``ResnetNoStage`` network and the config it was built from.

    The config is a fresh ``ResnetNoStageConfig`` with 6 keypoints,
    4 image channels, 1 depth value per keypoint and ``num_layers`` of 34.

    Returns:
        A ``(network, net_config)`` tuple.
    """
    config = ResnetNoStageConfig()
    # (attribute, value) pairs, applied in declaration order.
    overrides = (
        ("num_keypoints", 6),
        ("image_channels", 4),
        ("depth_per_keypoint", 1),
        ("num_layers", 34),
    )
    for field, value in overrides:
        setattr(config, field, value)
    model = ResnetNoStage(config)
    return model, config
def construct_network():
    """Create a ``ResnetNoStage`` network and the config it was built from.

    The config is a fresh ``ResnetNoStageConfig`` with 3 keypoints,
    4 image channels, 3 depth values per keypoint and ``num_layers`` of 34.

    Returns:
        A ``(network, net_config)`` tuple.
    """
    cfg = ResnetNoStageConfig()
    cfg_values = {
        "num_keypoints": 3,
        "image_channels": 4,
        # For integral heatmap, depthmap and regress heatmap
        "depth_per_keypoint": 3,
        "num_layers": 34,
    }
    for key, val in cfg_values.items():
        setattr(cfg, key, val)
    return ResnetNoStage(cfg), cfg
示例#4
0
def construct_network(for_timeseries_data=False):
    """Create a keypoint network and the config it was built from.

    The config is a fresh ``ResnetNoStageConfig`` with 2 keypoints,
    4 image channels, 2 depth values per keypoint and ``num_layers`` of 18.

    Args:
        for_timeseries_data: When truthy, the network is a
            ``ResnetNoStageLSTM``; otherwise it is a ``ResnetNoStage``.

    Returns:
        A ``(network, net_config)`` tuple.
    """
    config = ResnetNoStageConfig()
    for field, value in (
        ("num_keypoints", 2),
        ("image_channels", 4),
        ("depth_per_keypoint", 2),
        ("num_layers", 18),
    ):
        setattr(config, field, value)

    # Only the selected class name is looked up, so the unused variant does
    # not need to be importable.
    network_cls = ResnetNoStageLSTM if for_timeseries_data else ResnetNoStage
    return network_cls(config), config
    def test_output_size(self):
        """Check the output shape of a 50-layer ``ResnetNoStage``.

        Builds a network for 10 keypoints with 1 depth value per keypoint and
        4 input channels, initializes it through ``init_from_modelzoo``, runs
        a batch of all-zero images through it and asserts that the output has
        shape ``(batch, num_keypoints * depth_per_keypoint, size / 4, size / 4)``.

        NOTE(review): ``init_from_modelzoo`` presumably fetches pretrained
        weights and may need network access -- confirm before running offline.
        """
        from mankey.network.resnet_nostage import ResnetNoStageConfig, ResnetNoStage, init_from_modelzoo
        config = ResnetNoStageConfig()
        config.num_layers = 50
        config.num_keypoints = 10
        config.depth_per_keypoint = 1
        config.image_channels = 4
        net = ResnetNoStage(config)

        # Load from model zoo
        init_from_modelzoo(net, config)

        # Test on some dummy image
        batch_size = 10
        image_size = 256
        # The test expects each spatial dimension to shrink by this factor.
        output_stride = 4
        img = torch.zeros((batch_size, config.image_channels, image_size, image_size))
        # Only the shape is inspected, so skip building the autograd graph.
        with torch.no_grad():
            out = net(img)

        # Check it; integer division keeps the expected size an int, matching
        # the type of a tensor dimension.
        expected_size = image_size // output_stride
        self.assertEqual(out.shape[0], batch_size)
        self.assertEqual(out.shape[1], config.num_keypoints * config.depth_per_keypoint)
        self.assertEqual(out.shape[2], expected_size)
        self.assertEqual(out.shape[3], expected_size)