def test_mse_regression_loss():
    """MSELoss: zero error gives 0 and unit error gives 1, weighted or not."""
    shape = (1, 3, 3)
    cases = ((torch.zeros, 0.), (torch.ones, 1.))

    # Without target weight (default None).
    criterion = build_loss(dict(type='MSELoss'))
    for make_pred, expected in cases:
        pred = make_pred(shape)
        target = torch.zeros(shape)
        assert torch.allclose(
            criterion(pred, target), torch.tensor(expected))

    # With an all-ones target weight the expected values are unchanged.
    criterion = build_loss(dict(type='MSELoss', use_target_weight=True))
    for make_pred, expected in cases:
        pred = make_pred(shape)
        target = torch.zeros(shape)
        assert torch.allclose(
            criterion(pred, target, torch.ones_like(target)),
            torch.tensor(expected))
def test_bone_loss():
    """BoneLoss on a 3-joint chain (parents [0, 0, 1]), weighted or not."""
    ramp = [[[0, 0, 0], [1, 1, 1], [2, 2, 2]]]

    # Without target weight (default None).
    criterion = build_loss(dict(type='BoneLoss', joint_parents=[0, 0, 1]))
    pred = torch.zeros((1, 3, 3))
    target = torch.zeros((1, 3, 3))
    assert torch.allclose(criterion(pred, target), torch.tensor(0.))
    # `target` is `pred` scaled by 2; the expected loss is sqrt(3).
    pred = torch.tensor(ramp, dtype=torch.float32)
    target = pred * 2
    assert torch.allclose(criterion(pred, target), torch.tensor(3**0.5))

    # With target weight. The (1, 2) weight presumably has one entry per
    # non-root bone; all ones must leave the results unchanged.
    criterion = build_loss(
        dict(
            type='BoneLoss',
            joint_parents=[0, 0, 1],
            use_target_weight=True))
    pred = torch.zeros((1, 3, 3))
    target = torch.zeros((1, 3, 3))
    weight = torch.ones((1, 2))
    assert torch.allclose(criterion(pred, target, weight), torch.tensor(0.))
    pred = torch.tensor(ramp, dtype=torch.float32)
    target = pred * 2
    weight = torch.ones((1, 2))
    assert torch.allclose(
        criterion(pred, target, weight), torch.tensor(3**0.5))
def test_wing_loss():
    """WingLoss: zero error gives 0, unit error exceeds 0.5."""
    from mmpose.models import build_loss

    shape = (1, 3, 2)

    # WingLoss without target weight.
    criterion = build_loss(dict(type='WingLoss'))
    pred, target = torch.zeros(shape), torch.zeros(shape)
    assert torch.allclose(criterion(pred, target, None), torch.tensor(0.))
    # For a unit error only a lower bound is checked.
    pred, target = torch.ones(shape), torch.zeros(shape)
    assert torch.gt(criterion(pred, target, None), torch.tensor(.5))

    # WingLoss with an all-ones target weight.
    criterion = build_loss(dict(type='WingLoss', use_target_weight=True))
    pred, target = torch.zeros(shape), torch.zeros(shape)
    assert torch.allclose(
        criterion(pred, target, torch.ones_like(target)), torch.tensor(0.))
    pred, target = torch.ones(shape), torch.zeros(shape)
    assert torch.gt(
        criterion(pred, target, torch.ones_like(target)), torch.tensor(.5))
def test_smooth_l1_loss():
    """SmoothL1Loss: zero error gives 0, unit error gives 0.5."""
    shape = (1, 3, 2)
    cases = ((torch.zeros, 0.), (torch.ones, .5))

    # Without target weight (default None).
    criterion = build_loss(dict(type='SmoothL1Loss'))
    for make_pred, expected in cases:
        pred = make_pred(shape)
        target = torch.zeros(shape)
        assert torch.allclose(
            criterion(pred, target), torch.tensor(expected))

    # With an all-ones target weight the expected values are unchanged.
    criterion = build_loss(dict(type='SmoothL1Loss', use_target_weight=True))
    for make_pred, expected in cases:
        pred = make_pred(shape)
        target = torch.zeros(shape)
        assert torch.allclose(
            criterion(pred, target, torch.ones_like(target)),
            torch.tensor(expected))
def test_bce_loss():
    """BCELoss on constant predictions, with and without target weight."""
    from mmpose.models import build_loss

    # BCELoss without target weight (None).
    criterion = build_loss(dict(type='BCELoss'))
    pred = torch.zeros((1, 2))
    target = torch.zeros((1, 2))
    assert torch.allclose(criterion(pred, target), torch.tensor(0.))
    # Predicting 0.5 for label 0: the expected loss is -log(0.5).
    pred = torch.ones((1, 2)) * 0.5
    target = torch.zeros((1, 2))
    assert torch.allclose(
        criterion(pred, target), -torch.log(torch.tensor(0.5)))

    # BCELoss with target weight.
    criterion = build_loss(dict(type='BCELoss', use_target_weight=True))
    pred = torch.ones((1, 2)) * 0.5
    target = torch.zeros((1, 2))
    weight = torch.ones((1, 2))
    assert torch.allclose(
        criterion(pred, target, weight), -torch.log(torch.tensor(0.5)))
    # Zeroing the weight of one of the two elements halves the loss.
    weight[:, 0] = 0
    assert torch.allclose(
        criterion(pred, target, weight),
        -0.5 * torch.log(torch.tensor(0.5)))
    # A single-element weight is accepted too (presumably broadcast).
    weight = torch.ones(1)
    assert torch.allclose(
        criterion(pred, target, weight), -torch.log(torch.tensor(0.5)))
def test_joints_mse_loss():
    """JointsMSELoss / JointsOHKMMSELoss checks without target weight.

    This function was previously also named ``test_mse_loss``. This module
    defines ``test_mse_loss`` a second time later on, and that later
    definition rebound the name. As a result, these checks were never
    collected by pytest. The distinct name lets them run.
    """
    from mmpose.models import build_loss

    # test MSE loss without target weight
    loss_cfg = dict(type='JointsMSELoss')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))

    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(1.))

    # Only the first of two channels differs, so the mean error is 0.5.
    fake_pred = torch.zeros((1, 2, 64, 64))
    fake_pred[0, 0] += 1
    fake_label = torch.zeros((1, 2, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.5))

    # JointsOHKMMSELoss with its default config must raise ValueError here.
    with pytest.raises(ValueError):
        loss_cfg = dict(type='JointsOHKMMSELoss')
        loss = build_loss(loss_cfg)
        fake_pred = torch.zeros((1, 3, 64, 64))
        fake_label = torch.zeros((1, 3, 64, 64))
        assert torch.allclose(
            loss(fake_pred, fake_label, None), torch.tensor(0.))

    # A negative topk is rejected with an AssertionError.
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='JointsOHKMMSELoss', topk=-1)
        loss = build_loss(loss_cfg)
        fake_pred = torch.zeros((1, 3, 64, 64))
        fake_label = torch.zeros((1, 3, 64, 64))
        assert torch.allclose(
            loss(fake_pred, fake_label, None), torch.tensor(0.))

    # topk=2: every joint wrong -> 1.
    loss_cfg = dict(type='JointsOHKMMSELoss', topk=2)
    loss = build_loss(loss_cfg)
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(1.))

    # topk=2: one wrong joint -> 0.5. This presumably averages the 2
    # hardest joints, with per-joint losses 1 and 0.
    loss_cfg = dict(type='JointsOHKMMSELoss', topk=2)
    loss = build_loss(loss_cfg)
    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_pred[0, 0] += 1
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.5))
def test_semi_supervision_loss():
    """SemiSupervisionLoss: empty during warmup, then zero losses."""
    criterion = build_loss(
        dict(
            type='SemiSupervisionLoss',
            joint_parents=[0, 0, 1],
            warmup_iterations=1))

    pose = torch.rand((1, 3, 3))
    traj = torch.ones((1, 1, 3))
    predictions = dict(
        labeled_pose=pose.clone(), unlabeled_pose=pose, unlabeled_traj=traj)

    intrinsics = torch.tensor(
        [[1, 1, 1, 1, 0.1, 0.1, 0.1, 0, 0]], dtype=torch.float32)
    # The 2-D targets are the loss's own projection of the predicted
    # pose, so prediction and target agree exactly.
    target_2d = criterion.project_joints(pose + traj, intrinsics)
    targets = dict(unlabeled_target_2d=target_2d, intrinsics=intrinsics)

    # First call falls inside warmup_iterations=1: no losses returned.
    assert not criterion(predictions, targets)

    # Second call: warmup is over and both losses are zero.
    losses = criterion(predictions, targets)
    assert torch.allclose(losses['proj_loss'], torch.tensor(0.))
    assert torch.allclose(losses['bone_loss'], torch.tensor(0.))
def test_smoothl1_loss():
    """SmoothL1Loss on 2-D ``(1, 3)`` inputs without target weight."""
    # test SmoothL1Loss without target weight
    loss_cfg = dict(type='SmoothL1Loss')
    loss = build_loss(loss_cfg)

    fake_pred = torch.zeros((1, 3))
    fake_label = torch.zeros((1, 3))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))

    # A unit error on every element is expected to give 0.5.
    fake_pred = torch.ones((1, 3))
    fake_label = torch.zeros((1, 3))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(.5))
def test_adaptive_wing_loss():
    """AdaptiveWingLoss gives 0 when prediction equals target."""
    # test AdaptiveWingLoss without target weight
    loss_cfg = dict(type='AdaptiveWingLoss')
    loss = build_loss(loss_cfg)
    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))

    # test AdaptiveWingLoss with target weight
    # (weight shape [1, 3, 1]: presumably one value per joint channel)
    loss_cfg = dict(type='AdaptiveWingLoss', use_target_weight=True)
    loss = build_loss(loss_cfg)
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.ones((1, 3, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, torch.ones([1, 3, 1])), torch.tensor(0.))
def test_rle_loss():
    """Smoke-test RLELoss across its config options (no value checks)."""
    # Each entry: (extra config keys, whether to call .cpu() before use).
    # Calling .cpu() exercises RLELoss._apply(fn).
    variants = (
        ({}, False),                        # all defaults
        (dict(q_dis='gaussian'), False),    # Q(error) Gaussian, not Laplace
        (dict(size_average=False), True),   # RLELoss._apply(fn) via .cpu()
        (dict(size_average=False), False),  # size_average default True
        (dict(residual=False), False),      # residual default True
    )
    for extra_cfg, move_to_cpu in variants:
        criterion = build_loss(dict(type='RLELoss', **extra_cfg))
        if move_to_cpu:
            criterion.cpu()
        criterion(torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # With an all-ones target weight, for zero and non-zero predictions.
    criterion = build_loss(dict(type='RLELoss', use_target_weight=True))
    for make_pred in (torch.zeros, torch.ones):
        pred = make_pred((1, 3, 4))
        target = torch.zeros((1, 3, 2))
        criterion(pred, target, torch.ones_like(target))
def test_multi_loss_factory():
    """Check HeatmapLoss, AELoss and the MultiLossFactory config checks.

    The AE-loss section mutates ``fake_tags`` / ``fake_joints`` in place
    between calls. Each expected value depends on all earlier edits, so
    the statement order is significant.
    """
    from mmpose.models import build_loss

    # test heatmap loss
    loss_cfg = dict(type='HeatmapLoss')
    loss = build_loss(loss_cfg)

    # A batch-size mismatch between prediction (2) and label (1) must fail.
    with pytest.raises(AssertionError):
        fake_pred = torch.zeros((2, 3, 64, 64))
        fake_label = torch.zeros((1, 3, 64, 64))
        fake_mask = torch.zeros((1, 64, 64))
        loss(fake_pred, fake_label, fake_mask)

    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.zeros((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))

    # An all-zero mask removes the error entirely.
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.zeros((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))

    # An all-ones mask keeps it: unit error everywhere -> 1.
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.ones((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(1.))

    # test AE loss
    # fake_joints is assumed to be (batch, max_people, num_joints,
    # (tag_index, visible)); index into the 18 flattened tags. TODO confirm.
    fake_tags = torch.zeros((1, 18, 1))
    fake_joints = torch.zeros((1, 3, 2, 2), dtype=torch.int)
    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    # The return value is indexed as [0] / [1]; presumably (push, pull).
    assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))
    assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.))

    # Person 0 now has two visible joints with tags 1 and 0. The 0.25 on
    # [1] matches the variance of those two tags.
    fake_tags[0, 0, 0] = 1.
    fake_tags[0, 10, 0] = 0.
    fake_joints[0, 0, 0, :] = torch.IntTensor((0, 1))
    fake_joints[0, 0, 1, :] = torch.IntTensor((10, 1))
    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))
    assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.25))

    # Reset tag 0 and give person 1 two joints with identical tags (1, 1).
    # Person 0's tags are now 0 and 0, so [1] drops back to 0.
    fake_tags[0, 0, 0] = 0
    fake_tags[0, 7, 0] = 1.
    fake_tags[0, 17, 0] = 1.
    fake_joints[0, 1, 0, :] = torch.IntTensor((7, 1))
    fake_joints[0, 1, 1, :] = torch.IntTensor((17, 1))
    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.))

    loss_cfg = dict(type='AELoss', loss_type='max')
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))

    # An unsupported loss_type must raise ValueError.
    with pytest.raises(ValueError):
        loss_cfg = dict(type='AELoss', loss_type='min')
        loss = build_loss(loss_cfg)
        loss(fake_tags, fake_joints)

    # test MultiLossFactory
    # Each of the next five configs passes exactly one per-stage option as a
    # scalar instead of a list; each must be rejected with AssertionError.
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='MultiLossFactory',
            num_joints=2,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=True,
            push_loss_factor=[0.001],
            pull_loss_factor=[0.001],
            with_heatmaps_loss=[True],
            heatmaps_loss_factor=[1.0])
        loss = build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='MultiLossFactory',
            num_joints=2,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=[True],
            push_loss_factor=0.001,
            pull_loss_factor=[0.001],
            with_heatmaps_loss=[True],
            heatmaps_loss_factor=[1.0])
        loss = build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='MultiLossFactory',
            num_joints=2,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=[True],
            push_loss_factor=[0.001],
            pull_loss_factor=0.001,
            with_heatmaps_loss=[True],
            heatmaps_loss_factor=[1.0])
        loss = build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='MultiLossFactory',
            num_joints=2,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=[True],
            push_loss_factor=[0.001],
            pull_loss_factor=[0.001],
            with_heatmaps_loss=True,
            heatmaps_loss_factor=[1.0])
        loss = build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='MultiLossFactory',
            num_joints=2,
            num_stages=1,
            ae_loss_type='exp',
            with_ae_loss=[True],
            push_loss_factor=[0.001],
            pull_loss_factor=[0.001],
            with_heatmaps_loss=[True],
            heatmaps_loss_factor=1.0)
        loss = build_loss(loss_cfg)

    # With both the AE and heatmap losses disabled, every per-stage result
    # is None.
    loss_cfg = dict(
        type='MultiLossFactory',
        num_joints=17,
        num_stages=1,
        ae_loss_type='exp',
        with_ae_loss=[False],
        push_loss_factor=[0.001],
        pull_loss_factor=[0.001],
        with_heatmaps_loss=[False],
        heatmaps_loss_factor=[1.0])
    loss = build_loss(loss_cfg)
    # 34 output channels = 2 * num_joints, presumably 17 heatmaps followed
    # by 17 tag maps.
    fake_outputs = [torch.zeros((1, 34, 64, 64))]
    fake_heatmaps = [torch.zeros((1, 17, 64, 64))]
    fake_masks = [torch.ones((1, 64, 64))]
    fake_joints = [torch.zeros((1, 30, 17, 2))]
    heatmaps_losses, push_losses, pull_losses = \
        loss(fake_outputs, fake_heatmaps, fake_masks, fake_joints)
    assert heatmaps_losses == [None]
    assert pull_losses == [None]
    assert push_losses == [None]

    # With both enabled, one heatmap loss is produced for the single stage.
    loss_cfg = dict(
        type='MultiLossFactory',
        num_joints=17,
        num_stages=1,
        ae_loss_type='exp',
        with_ae_loss=[True],
        push_loss_factor=[0.001],
        pull_loss_factor=[0.001],
        with_heatmaps_loss=[True],
        heatmaps_loss_factor=[1.0])
    loss = build_loss(loss_cfg)
    heatmaps_losses, push_losses, pull_losses = \
        loss(fake_outputs, fake_heatmaps, fake_masks, fake_joints)
    assert len(heatmaps_losses) == 1
def test_gan_loss():
    """test gan loss."""
    # An unknown gan_type is rejected at build time.
    with pytest.raises(NotImplementedError):
        bad_cfg = dict(
            type='GANLoss',
            gan_type='test',
            real_label_val=1.0,
            fake_label_val=0.0,
            loss_weight=1)
        _ = build_loss(bad_cfg)

    input_1 = torch.ones(1, 1)
    input_2 = torch.ones(1, 3, 6, 6) * 2

    # Calls are made in this fixed order for every gan_type:
    # (target_is_real, is_disc).
    call_order = ((True, False), (False, False), (True, True), (False, True))
    # gan_type -> (input tensor, expected loss for each call in call_order)
    expectations = (
        ('vanilla', input_1, (0.6265233, 2.6265232, 0.3132616, 1.3132616)),
        ('lsgan', input_2, (2.0, 8.0, 1.0, 4.0)),
        ('wgan', input_2, (-4.0, 4, -2.0, 2.0)),
        ('hinge', input_2, (-4.0, -4.0, 0.0, 3.0)),
    )
    for gan_type, inputs, expected_values in expectations:
        gan_loss = build_loss(
            dict(
                type='GANLoss',
                gan_type=gan_type,
                real_label_val=1.0,
                fake_label_val=0.0,
                loss_weight=2.0))
        for (target_is_real, is_disc), expected in zip(call_order,
                                                       expected_values):
            value = gan_loss(inputs, target_is_real, is_disc=is_disc)
            assert_almost_equal(value.item(), expected)
def test_mesh_loss():
    """test mesh loss."""
    criterion = build_loss(
        dict(
            type='MeshLoss',
            joints_2d_loss_weight=1,
            joints_3d_loss_weight=1,
            vertex_loss_weight=1,
            smpl_pose_loss_weight=1,
            smpl_beta_loss_weight=1,
            img_res=256,
            focal_length=5000))

    smpl_pose = torch.zeros([1, 72], dtype=torch.float32)
    smpl_rotmat = batch_rodrigues(smpl_pose.view(-1, 3)).view(-1, 24, 3, 3)
    smpl_beta = torch.zeros([1, 10], dtype=torch.float32)
    camera = torch.tensor([[1, 0, 0]], dtype=torch.float32)
    vertices = torch.rand([1, 6890, 3], dtype=torch.float32)
    joints_3d = torch.ones([1, 24, 3], dtype=torch.float32)
    joints_2d = criterion.project_points(joints_3d, camera) + (256 - 1) / 2

    # Case 1: predictions match the ground truth, so every term is zero.
    prediction = dict(
        pose=smpl_rotmat,
        beta=smpl_beta,
        camera=camera,
        vertices=vertices,
        joints_3d=joints_3d)
    ground_truth = dict(
        pose=smpl_pose,
        beta=smpl_beta,
        vertices=vertices,
        has_smpl=torch.ones(1, dtype=torch.float32),
        joints_3d=joints_3d,
        joints_3d_visible=torch.ones([1, 24, 1], dtype=torch.float32),
        joints_2d=joints_2d,
        joints_2d_visible=torch.ones([1, 24, 1], dtype=torch.float32))
    losses = criterion(prediction, ground_truth)
    for key in ('vertex_loss', 'smpl_pose_loss', 'smpl_beta_loss',
                'joints_3d_loss', 'joints_2d_loss'):
        assert torch.allclose(losses[key], torch.tensor(0.))

    # Case 2: shift predictions and ground truth apart by known offsets.
    prediction = dict(
        pose=smpl_rotmat + 1,
        beta=smpl_beta + 1,
        camera=camera,
        vertices=vertices + 1,
        joints_3d=joints_3d.clone())
    # Only the first joint of the 3-D ground truth is moved.
    shifted_joints_3d = joints_3d.clone()
    shifted_joints_3d[:, 0] = shifted_joints_3d[:, 0] + 1
    ground_truth = dict(
        pose=smpl_pose,
        beta=smpl_beta,
        vertices=vertices,
        has_smpl=torch.ones(1, dtype=torch.float32),
        joints_3d=shifted_joints_3d,
        joints_3d_visible=torch.ones([1, 24, 1], dtype=torch.float32),
        joints_2d=joints_2d + (256 - 1) / 2,
        joints_2d_visible=torch.ones([1, 24, 1], dtype=torch.float32))
    losses = criterion(prediction, ground_truth)
    expected = {
        'vertex_loss': 1.,
        'smpl_pose_loss': 1.,
        'smpl_beta_loss': 1.,
        'joints_3d_loss': 0.5 / 24,
        'joints_2d_loss': 0.5,
    }
    for key, value in expected.items():
        assert torch.allclose(losses[key], torch.tensor(value))
def test_mse_loss():
    """Heatmap MSE losses: JointsMSE, JointsOHKMMSE and CombinedTargetMSE."""

    def check(criterion, pred, target, weight, expected):
        # Evaluate the loss first, then compare against a scalar tensor.
        assert torch.allclose(
            criterion(pred, target, weight), torch.tensor(expected))

    shape = (1, 3, 64, 64)

    # JointsMSELoss without target weight.
    criterion = build_loss(dict(type='JointsMSELoss'))
    check(criterion, torch.zeros(shape), torch.zeros(shape), None, 0.)
    check(criterion, torch.ones(shape), torch.zeros(shape), None, 1.)
    # Only the first of two channels differs, so the mean error is 0.5.
    pred = torch.zeros((1, 2, 64, 64))
    pred[0, 0] += 1
    check(criterion, pred, torch.zeros((1, 2, 64, 64)), None, 0.5)

    # JointsOHKMMSELoss with its default config must raise ValueError here.
    with pytest.raises(ValueError):
        criterion = build_loss(dict(type='JointsOHKMMSELoss'))
        check(criterion, torch.zeros(shape), torch.zeros(shape), None, 0.)

    # A negative topk is rejected with an AssertionError.
    with pytest.raises(AssertionError):
        criterion = build_loss(dict(type='JointsOHKMMSELoss', topk=-1))
        check(criterion, torch.zeros(shape), torch.zeros(shape), None, 0.)

    # topk=2: every joint wrong -> 1; one joint wrong -> 0.5. This presumably
    # averages the 2 hardest joints.
    criterion = build_loss(dict(type='JointsOHKMMSELoss', topk=2))
    check(criterion, torch.ones(shape), torch.zeros(shape), None, 1.)
    criterion = build_loss(dict(type='JointsOHKMMSELoss', topk=2))
    pred = torch.zeros(shape)
    pred[0, 0] += 1
    check(criterion, pred, torch.zeros(shape), None, 0.5)

    # CombinedTargetMSELoss: an all-ones weight gives 0.5, and an all-zero
    # weight masks the loss out completely.
    for make_weight, expected in ((torch.ones, 0.5), (torch.zeros, 0.)):
        criterion = build_loss(
            dict(type='CombinedTargetMSELoss', use_target_weight=True))
        pred = torch.ones(shape)
        target = torch.zeros(shape)
        check(criterion, pred, target, make_weight((1, 1, 1)), expected)