Example #1
0
def test_mse_regression_loss():
    """Check ``MSELoss`` against all-zero labels, with and without weights.

    In both configurations a zero prediction must give a loss of 0 and an
    all-ones prediction must give a loss of 1.
    """
    shape = (1, 3, 3)
    zero_loss, unit_loss = torch.tensor(0.), torch.tensor(1.)

    # default config: no target weight is passed (defaults to None)
    mse = build_loss(dict(type='MSELoss'))
    assert torch.allclose(
        mse(torch.zeros(shape), torch.zeros(shape)), zero_loss)
    assert torch.allclose(
        mse(torch.ones(shape), torch.zeros(shape)), unit_loss)

    # weighted config: a uniform weight of ones must leave results unchanged
    weighted_mse = build_loss(dict(type='MSELoss', use_target_weight=True))
    for pred, expected in ((torch.zeros(shape), zero_loss),
                           (torch.ones(shape), unit_loss)):
        target = torch.zeros(shape)
        assert torch.allclose(
            weighted_mse(pred, target, torch.ones_like(target)), expected)
Example #2
0
def test_bone_loss():
    """Check ``BoneLoss`` on a 3-joint chain with parents ``[0, 0, 1]``.

    Identical all-zero poses must give zero loss. A pose whose joints are
    spaced ``(1, 1, 1)`` apart, compared against the same pose scaled by 2,
    is expected to give ``sqrt(3)``. Both checks run without a target weight
    and with a weight of ones of shape (1, 2).
    """
    chain = torch.tensor([[[0, 0, 0], [1, 1, 1], [2, 2, 2]]],
                         dtype=torch.float32)
    expected_scaled = torch.tensor(3**0.5)

    # without target weight (defaults to None)
    bone_loss = build_loss(dict(type='BoneLoss', joint_parents=[0, 0, 1]))
    assert torch.allclose(
        bone_loss(torch.zeros((1, 3, 3)), torch.zeros((1, 3, 3))),
        torch.tensor(0.))
    assert torch.allclose(bone_loss(chain, chain * 2), expected_scaled)

    # with a target weight of ones
    bone_loss = build_loss(
        dict(type='BoneLoss',
             joint_parents=[0, 0, 1],
             use_target_weight=True))
    assert torch.allclose(
        bone_loss(torch.zeros((1, 3, 3)), torch.zeros((1, 3, 3)),
                  torch.ones((1, 2))),
        torch.tensor(0.))
    assert torch.allclose(
        bone_loss(chain, chain * 2, torch.ones((1, 2))), expected_scaled)
def test_wing_loss():
    """Check ``WingLoss`` with and without a target weight.

    In both configurations a zero prediction against a zero label must give
    zero loss, and an all-ones prediction must give a loss above 0.5.
    """
    from mmpose.models import build_loss

    shape = (1, 3, 2)
    # (config, whether to pass an all-ones weight instead of None)
    cases = (
        (dict(type='WingLoss'), False),
        (dict(type='WingLoss', use_target_weight=True), True),
    )
    for cfg, pass_weight in cases:
        wing = build_loss(cfg)

        label = torch.zeros(shape)
        weight = torch.ones_like(label) if pass_weight else None
        assert torch.allclose(
            wing(torch.zeros(shape), label, weight), torch.tensor(0.))

        label = torch.zeros(shape)
        weight = torch.ones_like(label) if pass_weight else None
        assert torch.gt(
            wing(torch.ones(shape), label, weight), torch.tensor(.5))
Example #4
0
def test_smooth_l1_loss():
    """Check ``SmoothL1Loss`` with and without a target weight.

    Against a zero label, a zero prediction must give 0 and an all-ones
    prediction must give 0.5, whether or not a weight of ones is supplied.
    """
    shape = (1, 3, 2)
    # (prediction factory, expected loss)
    expectations = ((torch.zeros, 0.), (torch.ones, .5))

    # without target weight (defaults to None)
    smooth_l1 = build_loss(dict(type='SmoothL1Loss'))
    for make_pred, expected in expectations:
        assert torch.allclose(
            smooth_l1(make_pred(shape), torch.zeros(shape)),
            torch.tensor(expected))

    # with an all-ones target weight
    smooth_l1 = build_loss(dict(type='SmoothL1Loss', use_target_weight=True))
    for make_pred, expected in expectations:
        label = torch.zeros(shape)
        assert torch.allclose(
            smooth_l1(make_pred(shape), label, torch.ones_like(label)),
            torch.tensor(expected))
def test_bce_loss():
    """Check ``BCELoss`` with and without a target weight.

    A prediction of 0.5 against a zero label is expected to give -log(0.5).
    Zeroing the weight of one of the two elements halves that value, and a
    single-element weight of one is expected to give the unweighted value.
    """
    from mmpose.models import build_loss

    half = torch.tensor(0.5)

    # without target weight (None)
    bce = build_loss(dict(type='BCELoss'))
    assert torch.allclose(
        bce(torch.zeros((1, 2)), torch.zeros((1, 2))), torch.tensor(0.))
    assert torch.allclose(
        bce(torch.ones((1, 2)) * 0.5, torch.zeros((1, 2))), -torch.log(half))

    # with target weight
    bce = build_loss(dict(type='BCELoss', use_target_weight=True))
    pred = torch.ones((1, 2)) * 0.5
    target = torch.zeros((1, 2))

    weight = torch.ones((1, 2))
    assert torch.allclose(bce(pred, target, weight), -torch.log(half))

    weight[:, 0] = 0  # drop the first element from the weighted mean
    assert torch.allclose(bce(pred, target, weight), -0.5 * torch.log(half))

    # a one-element weight (presumably broadcast over both elements)
    assert torch.allclose(bce(pred, target, torch.ones(1)), -torch.log(half))
def test_mse_loss():
    """Check ``JointsMSELoss`` and ``JointsOHKMMSELoss`` on heatmap inputs.

    NOTE(review): another ``test_mse_loss`` is defined later in this listing.
    If both end up in the same module, the later definition replaces this
    one and pytest never runs it. The later one repeats every check below,
    so one of them should probably be renamed or removed — confirm.
    """
    from mmpose.models import build_loss

    hm_shape = (1, 3, 64, 64)
    zero, half, one = torch.tensor(0.), torch.tensor(0.5), torch.tensor(1.)

    # plain joints MSE, no target weight
    mse = build_loss(dict(type='JointsMSELoss'))
    assert torch.allclose(
        mse(torch.zeros(hm_shape), torch.zeros(hm_shape), None), zero)
    assert torch.allclose(
        mse(torch.ones(hm_shape), torch.zeros(hm_shape), None), one)

    # only the first of two heatmaps is off by one everywhere
    two_joint_pred = torch.zeros((1, 2, 64, 64))
    two_joint_pred[0, 0] += 1
    assert torch.allclose(
        mse(two_joint_pred, torch.zeros((1, 2, 64, 64)), None), half)

    # default OHKM config is expected to raise ValueError on 3 joints
    with pytest.raises(ValueError):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss'))
        assert torch.allclose(
            ohkm(torch.zeros(hm_shape), torch.zeros(hm_shape), None), zero)

    # a negative topk must be rejected
    with pytest.raises(AssertionError):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss', topk=-1))
        assert torch.allclose(
            ohkm(torch.zeros(hm_shape), torch.zeros(hm_shape), None), zero)

    # topk=2: all joints wrong -> 1, only joint 0 wrong -> 0.5
    one_joint_wrong = torch.zeros(hm_shape)
    one_joint_wrong[0, 0] += 1
    for pred, expected in ((torch.ones(hm_shape), one),
                           (one_joint_wrong, half)):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss', topk=2))
        assert torch.allclose(
            ohkm(pred, torch.zeros(hm_shape), None), expected)
Example #7
0
def test_semi_supervision_loss():
    """Check ``SemiSupervisionLoss`` warm-up and its zero-error case.

    The unlabeled 2D target is produced by projecting the predicted
    unlabeled pose plus trajectory with the loss's own ``project_joints``.
    The first call falls inside the single warm-up iteration and must return
    no losses. On the next call the projection and bone losses must be zero.
    """
    semi_loss = build_loss(
        dict(type='SemiSupervisionLoss',
             joint_parents=[0, 0, 1],
             warmup_iterations=1))

    pose = torch.rand((1, 3, 3))
    traj = torch.ones((1, 1, 3))
    pred = {
        'labeled_pose': pose.clone(),
        'unlabeled_pose': pose,
        'unlabeled_traj': traj,
    }

    intrinsics = torch.tensor([[1, 1, 1, 1, 0.1, 0.1, 0.1, 0, 0]],
                              dtype=torch.float32)
    unlabeled_target_2d = semi_loss.project_joints(pose + traj, intrinsics)
    label = {
        'unlabeled_target_2d': unlabeled_target_2d,
        'intrinsics': intrinsics,
    }

    # warm-up: nothing is reported yet
    assert not semi_loss(pred, label)

    # after warm-up: the target was built from the prediction, so no error
    losses = semi_loss(pred, label)
    for term in ('proj_loss', 'bone_loss'):
        assert torch.allclose(losses[term], torch.tensor(0.))
def test_smoothl1_loss():
    """Check ``SmoothL1Loss`` on 2-D (1, 3) inputs with ``target_weight=None``.

    A zero prediction against a zero label must give 0, and an all-ones
    prediction against a zero label must give 0.5.
    """
    # test SmoothL1Loss without target weight (passed explicitly as None)
    loss_cfg = dict(type='SmoothL1Loss')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3))
    fake_label = torch.zeros((1, 3))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))

    fake_pred = torch.ones((1, 3))
    fake_label = torch.zeros((1, 3))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(.5))
def test_adaptive_wing_loss():
    """Check that ``AdaptiveWingLoss`` is zero when prediction equals label.

    Covered without a target weight (all-zero heatmaps) and with a target
    weight of shape (1, 3, 1), i.e. one entry per heatmap channel (all-one
    heatmaps).
    """
    # test AdaptiveWingLoss without target weight
    loss_cfg = dict(type='AdaptiveWingLoss')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))

    # test AdaptiveWingLoss with target weight
    loss_cfg = dict(type='AdaptiveWingLoss', use_target_weight=True)
    loss = build_loss(loss_cfg)

    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.ones((1, 3, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, torch.ones([1, 3, 1])), torch.tensor(0.))
Example #10
0
def test_rle_loss():
    """Smoke-test ``RLELoss`` across its configuration options.

    Exact loss values are not pinned here. Every call must run without error
    and return a finite value, so a NaN/inf regression in any configuration
    fails the test instead of passing silently.
    """

    def _assert_finite(value):
        # as_tensor accepts either a tensor or a plain Python number
        assert torch.isfinite(torch.as_tensor(value)).all()

    # test RLELoss without target weight(default None)
    loss_cfg = dict(type='RLELoss')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label))

    # test RLELoss with Q(error) changed to "Gaussian"(default "Laplace")
    loss_cfg = dict(type='RLELoss', q_dis='gaussian')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label))

    # test RLELoss._apply(fn): nn.Module.cpu() is implemented via _apply, so
    # moving the loss to CPU exercises any _apply override. The
    # size_average=False setting here is incidental (it is tested below).
    loss_cfg = dict(type='RLELoss', size_average=False)
    loss = build_loss(loss_cfg)
    loss.cpu()

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label))

    # test RLELoss with size_average(default True) changed to False
    loss_cfg = dict(type='RLELoss', size_average=False)
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label))

    # test RLELoss with residual(default True) changed to False
    loss_cfg = dict(type='RLELoss', residual=False)
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label))

    # test RLELoss with target weight
    loss_cfg = dict(type='RLELoss', use_target_weight=True)
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label, torch.ones_like(fake_label)))

    fake_pred = torch.ones((1, 3, 4))
    fake_label = torch.zeros((1, 3, 2))
    _assert_finite(loss(fake_pred, fake_label, torch.ones_like(fake_label)))
def test_multi_loss_factory():
    """Check ``HeatmapLoss``, ``AELoss`` and ``MultiLossFactory``.

    Covers:
    * shape validation and masking in ``HeatmapLoss``;
    * the two outputs of ``AELoss`` for hand-placed tags, plus rejection of
      an unknown ``loss_type``;
    * ``MultiLossFactory`` rejecting scalar per-stage options, and its
      per-stage outputs when its sub-losses are disabled or enabled.
    """
    from mmpose.models import build_loss

    hm_shape = (1, 3, 64, 64)
    zero = torch.tensor(0.)

    # ---- HeatmapLoss ----
    heatmap_loss = build_loss(dict(type='HeatmapLoss'))

    # prediction batch (2) and label batch (1) disagree
    with pytest.raises(AssertionError):
        heatmap_loss(torch.zeros((2, 3, 64, 64)), torch.zeros(hm_shape),
                     torch.zeros((1, 64, 64)))

    # a zero mask hides even a wrong prediction; a full mask exposes it
    heatmap_cases = (
        (torch.zeros(hm_shape), torch.zeros((1, 64, 64)), zero),
        (torch.ones(hm_shape), torch.zeros((1, 64, 64)), zero),
        (torch.ones(hm_shape), torch.ones((1, 64, 64)), torch.tensor(1.)),
    )
    for pred, mask, expected in heatmap_cases:
        assert torch.allclose(
            heatmap_loss(pred, torch.zeros(hm_shape), mask), expected)

    # ---- AELoss ----
    def _ae(loss_type):
        return build_loss(dict(type='AELoss', loss_type=loss_type))

    tags = torch.zeros((1, 18, 1))
    joints = torch.zeros((1, 3, 2, 2), dtype=torch.int)

    # all tags equal: both outputs are zero
    ae_exp = _ae('exp')
    assert torch.allclose(ae_exp(tags, joints)[0], zero)
    assert torch.allclose(ae_exp(tags, joints)[1], zero)

    # slot 0 now references tags 1 (index 0) and 0 (index 10)
    tags[0, 0, 0] = 1.
    tags[0, 10, 0] = 0.
    joints[0, 0, 0, :] = torch.IntTensor((0, 1))
    joints[0, 0, 1, :] = torch.IntTensor((10, 1))
    ae_exp = _ae('exp')
    assert torch.allclose(ae_exp(tags, joints)[0], zero)
    assert torch.allclose(ae_exp(tags, joints)[1], torch.tensor(0.25))

    # slot 0 tags become equal again; slot 1 gets equal tags (indices 7, 17)
    tags[0, 0, 0] = 0
    tags[0, 7, 0] = 1.
    tags[0, 17, 0] = 1.
    joints[0, 1, 0, :] = torch.IntTensor((7, 1))
    joints[0, 1, 1, :] = torch.IntTensor((17, 1))
    assert torch.allclose(_ae('exp')(tags, joints)[1], zero)
    assert torch.allclose(_ae('max')(tags, joints)[0], zero)

    with pytest.raises(ValueError):
        _ae('min')(tags, joints)

    # ---- MultiLossFactory ----
    def _factory_cfg(num_joints, **overrides):
        # fresh dict and fresh lists on every call; overrides keep key order
        cfg = dict(
            type='MultiLossFactory',
            num_joints=num_joints,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=[True],
            push_loss_factor=[0.001],
            pull_loss_factor=[0.001],
            with_heatmaps_loss=[True],
            heatmaps_loss_factor=[1.0])
        cfg.update(overrides)
        return cfg

    # any per-stage option given as a scalar instead of a list is rejected
    scalar_overrides = (
        ('with_ae_loss', True),
        ('push_loss_factor', 0.001),
        ('pull_loss_factor', 0.001),
        ('with_heatmaps_loss', True),
        ('heatmaps_loss_factor', 1.0),
    )
    for key, scalar in scalar_overrides:
        with pytest.raises(AssertionError):
            build_loss(_factory_cfg(2, **{key: scalar}))

    # presumably 34 = 17 heatmap + 17 tag channels for 17 joints
    disabled = build_loss(
        _factory_cfg(17, with_ae_loss=[False], with_heatmaps_loss=[False]))
    outputs = [torch.zeros((1, 34, 64, 64))]
    heatmaps = [torch.zeros((1, 17, 64, 64))]
    masks = [torch.ones((1, 64, 64))]
    stage_joints = [torch.zeros((1, 30, 17, 2))]

    heatmaps_losses, push_losses, pull_losses = disabled(
        outputs, heatmaps, masks, stage_joints)
    assert heatmaps_losses == [None]
    assert pull_losses == [None]
    assert push_losses == [None]

    enabled = build_loss(_factory_cfg(17))
    heatmaps_losses, _, _ = enabled(outputs, heatmaps, masks, stage_joints)
    assert len(heatmaps_losses) == 1
Example #12
0
def test_gan_loss():
    """test gan loss.

    An unknown ``gan_type`` must raise ``NotImplementedError``. For each
    supported type, the loss is evaluated for real/fake targets in both
    generator and discriminator mode. Per the expected values,
    ``loss_weight=2.0`` scales the generator losses but not the
    discriminator losses.
    """

    def _gan_cfg(gan_type, loss_weight):
        return dict(type='GANLoss',
                    gan_type=gan_type,
                    real_label_val=1.0,
                    fake_label_val=0.0,
                    loss_weight=loss_weight)

    with pytest.raises(NotImplementedError):
        _ = build_loss(_gan_cfg('test', 1))

    logits = torch.ones(1, 1)
    feature_map = torch.ones(1, 3, 6, 6) * 2

    # (target_is_real, is_disc) in the order the expected values are listed
    call_order = ((True, False), (False, False), (True, True), (False, True))
    cases = (
        ('vanilla', logits, (0.6265233, 2.6265232, 0.3132616, 1.3132616)),
        ('lsgan', feature_map, (2.0, 8.0, 1.0, 4.0)),
        ('wgan', feature_map, (-4.0, 4, -2.0, 2.0)),
        ('hinge', feature_map, (-4.0, -4.0, 0.0, 3.0)),
    )
    for gan_type, inputs, expected_values in cases:
        gan_loss = build_loss(_gan_cfg(gan_type, 2.0))
        for (target_is_real, is_disc), expected in zip(call_order,
                                                       expected_values):
            value = gan_loss(inputs, target_is_real, is_disc=is_disc)
            assert_almost_equal(value.item(), expected)
Example #13
0
def test_mesh_loss():
    """test mesh loss.

    First every prediction matches its ground truth, and all five loss terms
    must be zero. Then pose, beta and vertices are offset by 1, joint 0 of
    the 3D ground truth is moved by 1, and the 2D ground truth is shifted by
    ``(256 - 1) / 2``. Each term is then compared with its expected value.
    """
    mesh_loss = build_loss(
        dict(type='MeshLoss',
             joints_2d_loss_weight=1,
             joints_3d_loss_weight=1,
             vertex_loss_weight=1,
             smpl_pose_loss_weight=1,
             smpl_beta_loss_weight=1,
             img_res=256,
             focal_length=5000))

    smpl_pose = torch.zeros([1, 72], dtype=torch.float32)
    smpl_rotmat = batch_rodrigues(smpl_pose.view(-1, 3)).view(-1, 24, 3, 3)
    smpl_beta = torch.zeros([1, 10], dtype=torch.float32)
    camera = torch.tensor([[1, 0, 0]], dtype=torch.float32)
    vertices = torch.rand([1, 6890, 3], dtype=torch.float32)
    joints_3d = torch.ones([1, 24, 3], dtype=torch.float32)
    shift = (256 - 1) / 2
    joints_2d = mesh_loss.project_points(joints_3d, camera) + shift

    def _gt(gt_joints_3d, gt_joints_2d):
        # every ground-truth field except the joints is shared by both cases
        return {
            'pose': smpl_pose,
            'beta': smpl_beta,
            'vertices': vertices,
            'has_smpl': torch.ones(1, dtype=torch.float32),
            'joints_3d': gt_joints_3d,
            'joints_3d_visible': torch.ones([1, 24, 1], dtype=torch.float32),
            'joints_2d': gt_joints_2d,
            'joints_2d_visible': torch.ones([1, 24, 1], dtype=torch.float32),
        }

    # case 1: perfect prediction
    exact_pred = {
        'pose': smpl_rotmat,
        'beta': smpl_beta,
        'camera': camera,
        'vertices': vertices,
        'joints_3d': joints_3d,
    }
    terms = mesh_loss(exact_pred, _gt(joints_3d, joints_2d))
    for name in ('vertex_loss', 'smpl_pose_loss', 'smpl_beta_loss',
                 'joints_3d_loss', 'joints_2d_loss'):
        assert torch.allclose(terms[name], torch.tensor(0.))

    # case 2: known perturbations
    perturbed_pred = {
        'pose': smpl_rotmat + 1,
        'beta': smpl_beta + 1,
        'camera': camera,
        'vertices': vertices + 1,
        'joints_3d': joints_3d.clone(),
    }
    moved_joints_3d = joints_3d.clone()
    moved_joints_3d[:, 0] = moved_joints_3d[:, 0] + 1
    terms = mesh_loss(perturbed_pred, _gt(moved_joints_3d, joints_2d + shift))
    expected_terms = {
        'vertex_loss': 1.,
        'smpl_pose_loss': 1.,
        'smpl_beta_loss': 1.,
        'joints_3d_loss': 0.5 / 24,
        'joints_2d_loss': 0.5,
    }
    for name, expected in expected_terms.items():
        assert torch.allclose(terms[name], torch.tensor(expected))
def test_mse_loss():
    """Check ``JointsMSELoss``, ``JointsOHKMMSELoss`` and
    ``CombinedTargetMSELoss`` on heatmap-shaped inputs.

    NOTE(review): an earlier ``test_mse_loss`` in this listing uses the same
    name. If both live in one module, this definition replaces that one.
    Every check of the earlier version is repeated here.
    """
    shape = (1, 3, 64, 64)

    def _expect(value, expected):
        assert torch.allclose(value, torch.tensor(expected))

    # test MSE loss without target weight
    mse = build_loss(dict(type='JointsMSELoss'))
    _expect(mse(torch.zeros(shape), torch.zeros(shape), None), 0.)
    _expect(mse(torch.ones(shape), torch.zeros(shape), None), 1.)

    first_of_two_wrong = torch.zeros((1, 2, 64, 64))
    first_of_two_wrong[0, 0] += 1
    _expect(mse(first_of_two_wrong, torch.zeros((1, 2, 64, 64)), None), 0.5)

    # default OHKM config is expected to raise ValueError on 3 joints
    with pytest.raises(ValueError):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss'))
        _expect(ohkm(torch.zeros(shape), torch.zeros(shape), None), 0.)

    # a negative topk must be rejected
    with pytest.raises(AssertionError):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss', topk=-1))
        _expect(ohkm(torch.zeros(shape), torch.zeros(shape), None), 0.)

    # topk=2: every joint wrong, then only joint 0 wrong
    all_wrong = torch.ones(shape)
    joint0_wrong = torch.zeros(shape)
    joint0_wrong[0, 0] += 1
    for pred, expected in ((all_wrong, 1.), (joint0_wrong, 0.5)):
        ohkm = build_loss(dict(type='JointsOHKMMSELoss', topk=2))
        _expect(ohkm(pred, torch.zeros(shape), None), expected)

    # combined-target MSE: a unit weight gives 0.5, a zero weight gives 0
    for make_weight, expected in ((torch.ones, 0.5), (torch.zeros, 0.)):
        combined = build_loss(
            dict(type='CombinedTargetMSELoss', use_target_weight=True))
        _expect(
            combined(torch.ones(shape), torch.zeros(shape),
                     make_weight((1, 1, 1))),
            expected)