def get_default_network(
        pose_resnet_config: PoseResNetModelConfig,
        skeleton_type=None) -> Pose2DNetResnet:
    """Build the default 2D pose network from a PoseResNet config.

    Creates the model via ``pose_resnet.get_pose_net(pose_resnet_config)``,
    loads the weights stored at ``pose_resnet_config.model_state_file``,
    moves the model to the GPU, switches it to eval mode and wraps it in a
    ``Pose2DNetResnet``.

    Args:
        pose_resnet_config: Configuration passed to
            ``pose_resnet.get_pose_net``; its ``model_state_file`` attribute
            is the path of the checkpoint to load.
        skeleton_type: Skeleton type handed to ``Pose2DNetResnet``. When
            ``None`` (the default), ``SkeletonStickman`` is used, matching
            the previous hard-coded behavior.

    Returns:
        The wrapped pose network, on the GPU and in eval mode.
    """
    if skeleton_type is None:
        skeleton_type = SkeletonStickman

    model = pose_resnet.get_pose_net(pose_resnet_config)

    # Load the checkpoint onto the CPU first: without map_location, torch.load
    # restores tensors onto the device they were saved from, which fails when
    # that device (e.g. cuda:1) does not exist on this machine and otherwise
    # allocates a transient GPU copy. The model is moved to the GPU below.
    state_dict = torch.load(pose_resnet_config.model_state_file,
                            map_location='cpu')
    model.load_state_dict(state_dict)

    model = model.cuda()
    model.eval()  # inference mode for BatchNorm/Dropout layers
    return Pose2DNetResnet(model, skeleton_type)
Ejemplo n.º 2
0
    buffer_size = 20
    action_names = [Action.IDLE.name, Action.WALK.name, Action.WAVE.name]
    use_action_recognition = True
    use_quick_n_dirty = False

    # Input Provider
    input_provider = WebcamProvider(camera_number=0,
                                    image_size=image_size,
                                    fps=fps)
    # input_provider = ImgDirProvider(
    #     "/media/disks/beta/records/real_cam/2019_03_13_Freilichtmuseum_Dashcam_01/full",
    #     image_size=image_size, fps=fps)
    fps_tracker = FPSTracker(average_over_seconds=1)

    # Pose Network
    pose_model = pose_resnet.get_pose_net(pose_resnet_config)

    logger.info('=> loading model from {}'.format(
        pose_resnet_config.model_state_file))
    pose_model.load_state_dict(torch.load(pose_resnet_config.model_state_file))
    pose_model = pose_model.cuda()
    pose_model.eval()
    pose_net = Pose2DNetResnet(pose_model, skeleton_type)
    pose_tracker = PoseTracker(image_size=image_size,
                               skeleton_type=skeleton_type)

    # Action Network
    action_model = ShuffleNetV2(input_size=32, n_class=3)
    state_dict = torch.load(ehpi_model_state_file)
    action_model.load_state_dict(state_dict)
    action_model.cuda()