def test_ce_loss():
    """Check CrossEntropyLoss with and without class weights.

    Covers the use_mask/use_sigmoid config assertion, list class weights,
    class weights loaded from ``.pkl`` and ``.npy`` files, plain softmax CE,
    sigmoid BCE, and BCE with an ``ignore_index``.

    NOTE(review): a second ``def test_ce_loss(...)`` appears later in this
    file and rebinds the same module-level name. pytest therefore only
    collects the later definition, and this one never runs. Consider
    renaming or removing it once the redundancy is confirmed.
    """
    import os
    import tempfile

    import mmcv
    import numpy as np

    from mmseg.models import build_loss

    # use_mask and use_sigmoid cannot be true at the same time
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='CrossEntropyLoss',
            use_mask=True,
            use_sigmoid=True,
            loss_weight=1.0)
        build_loss(loss_cfg)

    # test loss with class weights
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=[0.8, 0.2],
        loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    # The unweighted CE for this sample is 200 (asserted below). The class-1
    # weight of 0.2 scales it to 40.
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

    # Test loss with class weights loaded from file. The temporary directory
    # removes the weight files even when an assertion inside it fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pkl_path = os.path.join(tmp_dir, 'class_weight.pkl')
        mmcv.dump([0.8, 0.2], pkl_path, 'pkl')  # from pkl file
        loss_cls_cfg = dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=pkl_path,
            loss_weight=1.0)
        loss_cls = build_loss(loss_cls_cfg)
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(40.))

        npy_path = os.path.join(tmp_dir, 'class_weight.npy')
        np.save(npy_path, np.array([0.8, 0.2]))  # from npy file
        loss_cls_cfg = dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=npy_path,
            loss_weight=1.0)
        loss_cls = build_loss(loss_cls_cfg)
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(40.))

    # without class weights
    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # sigmoid (BCE) variant on the same sample
    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))

    # BCE on a 21-class dense prediction, first without then with an
    # ignored pixel per image
    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
    fake_label = torch.ones(2, 8, 8).long()
    assert torch.allclose(
        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
    fake_label[:, 0, 0] = 255
    assert torch.allclose(
        loss_cls(fake_pred, fake_label, ignore_index=255),
        torch.tensor(0.9354),
        atol=1e-4)
def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
    """Compare CrossEntropyLoss (softmax CE / sigmoid BCE) with torch refs.

    Args:
        use_sigmoid: Build the loss with ``use_sigmoid`` (BCE) instead of
            softmax CE.
        reduction: ``reduction`` used by the two configs in the
            ``avg_non_ignore`` consistency check.
        avg_non_ignore: Passed through to the loss config. When True, the
            reference losses below divide by the number of non-ignored
            elements. Otherwise they divide by the total element count.
        bce_input_same_dim: On the BCE path, use predictions and labels of
            identical shape instead of class-index labels.

    NOTE(review): this definition rebinds the name of the earlier
    zero-argument ``test_ce_loss`` in this file. Its arguments are presumably
    supplied by ``pytest.mark.parametrize`` decorators not visible here.
    Confirm.
    """
    from mmseg.models import build_loss

    # use_mask and use_sigmoid cannot be true at the same time
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='CrossEntropyLoss',
            use_mask=True,
            use_sigmoid=True,
            loss_weight=1.0)
        build_loss(loss_cfg)

    # test loss with simple case for ce/bce
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        loss_weight=1.0,
        avg_non_ignore=avg_non_ignore,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    if use_sigmoid:
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(100.))
    else:
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # test loss with complicated case for ce/bce
    # when avg_non_ignore is False, `avg_factor` would not be calculated
    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
    fake_label = torch.ones(2, 8, 8).long()
    fake_label[:, 0, 0] = 255
    fake_weight = None
    # extra test bce loss when pred.shape == label.shape
    if use_sigmoid and bce_input_same_dim:
        fake_pred = torch.randn(2, 10).float()
        fake_label = torch.rand(2, 10).float()
        fake_weight = torch.rand(2, 10)  # set weight in forward function
        fake_label[0, [1, 2, 5, 7]] = 255  # set ignore_index
        fake_label[1, [0, 5, 8, 9]] = 255
    loss_cls = build_loss(loss_cls_cfg)
    loss = loss_cls(
        fake_pred, fake_label, weight=fake_weight, ignore_index=255)
    # Recompute the expected value directly with torch.nn.functional.
    if use_sigmoid:
        if fake_pred.dim() != fake_label.dim():
            # `_expand_onehot_labels` is defined outside this view. It
            # presumably one-hot expands the labels to `target_shape` and
            # returns (labels, weights, valid_mask). Confirm.
            fake_label, weight, valid_mask = _expand_onehot_labels(
                labels=fake_label,
                label_weights=None,
                target_shape=fake_pred.shape,
                ignore_index=255)
        else:
            # should mask out the ignored elements
            valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
            weight = valid_mask
        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            fake_pred,
            fake_label.float(),
            reduction='none',
            weight=fake_weight)
        if avg_non_ignore:
            # average over non-ignored elements only
            avg_factor = valid_mask.sum().item()
            torch_loss = (torch_loss * weight).sum() / avg_factor
        else:
            # average over every element, ignored ones included
            torch_loss = (torch_loss * weight).mean()
    else:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred, fake_label, reduction='mean', ignore_index=255)
        else:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred, fake_label, reduction='sum',
                ignore_index=255) / fake_label.numel()
    assert torch.allclose(loss, torch_loss)

    if use_sigmoid:
        # test loss with complicated case for ce/bce
        # when avg_non_ignore is False, `avg_factor` would not be calculated
        fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
        fake_label = torch.ones(2, 8, 8).long()
        fake_label[:, 0, 0] = 255
        # per-pixel weight passed to forward(); for the reference it is
        # broadcast over the class dim via unsqueeze(1).expand below
        fake_weight = torch.rand(2, 8, 8)

        loss_cls = build_loss(loss_cls_cfg)
        loss = loss_cls(
            fake_pred, fake_label, weight=fake_weight, ignore_index=255)
        # NOTE(review): this inner check is always true here because it is
        # already inside the outer `if use_sigmoid:`.
        if use_sigmoid:
            fake_label, weight, valid_mask = _expand_onehot_labels(
                labels=fake_label,
                label_weights=None,
                target_shape=fake_pred.shape,
                ignore_index=255)
            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                fake_pred,
                fake_label.float(),
                reduction='none',
                weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
            if avg_non_ignore:
                avg_factor = valid_mask.sum().item()
                torch_loss = (torch_loss * weight).sum() / avg_factor
            else:
                torch_loss = (torch_loss * weight).mean()
        assert torch.allclose(loss, torch_loss)

    # test loss with class weights from file
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    import os
    import tempfile

    import mmcv
    import numpy as np
    # NOTE(review): the temporary object is only used to reserve a unique
    # name. The derived .pkl/.npy files are removed only if every assertion
    # before the os.remove calls passes. A `tempfile.TemporaryDirectory()`
    # context manager would clean up on failure too.
    tmp_file = tempfile.NamedTemporaryFile()

    mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=f'{tmp_file.name}.pkl',
        loss_weight=1.0,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

    np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2]))  # from npy file
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=f'{tmp_file.name}.npy',
        loss_weight=1.0,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
    tmp_file.close()
    os.remove(f'{tmp_file.name}.pkl')
    os.remove(f'{tmp_file.name}.npy')

    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # test `avg_non_ignore` without ignore index would not affect ce/bce loss
    # when reduction='sum'/'none'/'mean'
    # (fake_label is [1] here, so no element equals ignore_index=255)
    loss_cls_cfg1 = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction=reduction,
        loss_weight=1.0,
        avg_non_ignore=True)
    loss_cls1 = build_loss(loss_cls_cfg1)
    loss_cls_cfg2 = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction=reduction,
        loss_weight=1.0,
        avg_non_ignore=False)
    loss_cls2 = build_loss(loss_cls_cfg2)
    assert torch.allclose(
        loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
        loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
        atol=1e-4)

    # test ce/bce loss with ignore index and class weight
    # in 5-way classification
    if use_sigmoid:
        # test bce loss when pred.shape == or != label.shape
        if bce_input_same_dim:
            fake_pred = torch.randn(2, 10).float()
            fake_label = torch.rand(2, 10).float()
            class_weight = torch.rand(2, 10)
        else:
            fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
            fake_label = torch.ones(2, 8, 8).long()
            class_weight = torch.randn(2, 21, 8, 8)
            fake_label, weight, valid_mask = _expand_onehot_labels(
                labels=fake_label,
                label_weights=None,
                target_shape=fake_pred.shape,
                ignore_index=-100)
        # on the BCE path the class weight is compared against `pos_weight`
        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            fake_pred,
            fake_label.float(),
            reduction='mean',
            pos_weight=class_weight)
    else:
        fake_pred = torch.randn(2, 5, 10).float()  # 5-way classification
        fake_label = torch.randint(0, 5, (2, 10)).long()
        class_weight = torch.rand(5)
        class_weight /= class_weight.sum()
        torch_loss = torch.nn.functional.cross_entropy(
            fake_pred, fake_label, reduction='sum',
            weight=class_weight) / fake_label.numel()
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction='mean',
        class_weight=class_weight,
        loss_weight=1.0,
        avg_non_ignore=avg_non_ignore)
    loss_cls = build_loss(loss_cls_cfg)

    # test cross entropy loss has name `loss_ce`
    assert loss_cls.loss_name == 'loss_ce'
    # test avg_non_ignore is in extra_repr
    assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'

    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch_loss)

    # mark a fixed set of positions with label 10 and treat 10 as ignored
    fake_label[0, [1, 2, 5, 7]] = 10  # set ignore_index
    fake_label[1, [0, 5, 8, 9]] = 10
    loss = loss_cls(fake_pred, fake_label, ignore_index=10)
    if use_sigmoid:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                fake_pred[fake_label != 10],
                fake_label[fake_label != 10].float(),
                pos_weight=class_weight[fake_label != 10],
                reduction='mean')
        else:
            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                fake_pred[fake_label != 10],
                fake_label[fake_label != 10].float(),
                pos_weight=class_weight[fake_label != 10],
                reduction='sum') / fake_label.numel()
    else:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred,
                fake_label,
                ignore_index=10,
                reduction='sum',
                weight=class_weight) / fake_label[fake_label != 10].numel()
        else:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred,
                fake_label,
                ignore_index=10,
                reduction='sum',
                weight=class_weight) / fake_label.numel()
    assert torch.allclose(loss, torch_loss)
def test_lovasz_loss():
    """Smoke-test LovaszLoss config validation and forward passes.

    Covers the ``multi_class`` and ``binary`` loss types with and without
    ``per_image``. It also covers ``class_weight`` given as a list or loaded
    from a ``.pkl`` / ``.npy`` file. Apart from the two config assertions,
    the forward calls are only checked for not raising.
    """
    import os
    import tempfile

    import mmcv
    import numpy as np

    from mmseg.models import build_loss

    # loss_type should be 'binary' or 'multi_class'
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LovaszLoss',
            loss_type='Binary',
            reduction='none',
            loss_weight=1.0)
        build_loss(loss_cfg)

    # reduction should be 'none' when per_image is False.
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
        build_loss(loss_cfg)

    # test lovasz loss with loss_type = 'multi_class' and per_image = False
    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'multi_class' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)

    # Test loss with class weights from file. The temporary directory
    # removes the weight files even if one of the calls raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pkl_path = os.path.join(tmp_dir, 'class_weight.pkl')
        mmcv.dump([1.0, 2.0, 3.0], pkl_path, 'pkl')  # from pkl file
        loss_cfg = dict(
            type='LovaszLoss',
            per_image=True,
            reduction='mean',
            class_weight=pkl_path,
            loss_weight=1.0)
        lovasz_loss = build_loss(loss_cfg)
        lovasz_loss(logits, labels, ignore_index=None)

        npy_path = os.path.join(tmp_dir, 'class_weight.npy')
        np.save(npy_path, np.array([1.0, 2.0, 3.0]))  # from npy file
        loss_cfg = dict(
            type='LovaszLoss',
            per_image=True,
            reduction='mean',
            class_weight=npy_path,
            loss_weight=1.0)
        lovasz_loss = build_loss(loss_cfg)
        lovasz_loss(logits, labels, ignore_index=None)

    # test lovasz loss with loss_type = 'binary' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    # torch.rand samples from [0, 1), so scale by 2 before truncating.
    # Otherwise every label becomes 0 and the positive class is never seen.
    labels = (torch.rand(2, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'binary' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = (torch.rand(2, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)
def test_dice_lose():
    """Smoke-test DiceLoss construction, forward passes and ``loss_name``.

    It exercises a 3-class config with ``class_weight`` given as a list and
    loaded from ``.pkl`` / ``.npy`` files. It also exercises a 2-class
    config with custom ``smooth`` / ``exponent``. Forward calls are only
    checked for not raising.
    """
    import os
    import tempfile

    import mmcv
    import numpy as np

    from mmseg.models import build_loss

    # test dice loss with loss_type = 'multi_class'
    loss_cfg = dict(
        type='DiceLoss',
        reduction='none',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0,
        ignore_index=1,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    logits = torch.rand(8, 3, 4, 4)
    labels = (torch.rand(8, 4, 4) * 3).long()
    dice_loss(logits, labels)

    # Test loss with class weights from file. The temporary directory
    # removes the weight files even if one of the calls raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pkl_path = os.path.join(tmp_dir, 'class_weight.pkl')
        mmcv.dump([1.0, 2.0, 3.0], pkl_path, 'pkl')  # from pkl file
        loss_cfg = dict(
            type='DiceLoss',
            reduction='none',
            class_weight=pkl_path,
            loss_weight=1.0,
            ignore_index=1,
            loss_name='loss_dice')
        dice_loss = build_loss(loss_cfg)
        dice_loss(logits, labels, ignore_index=None)

        npy_path = os.path.join(tmp_dir, 'class_weight.npy')
        np.save(npy_path, np.array([1.0, 2.0, 3.0]))  # from npy file
        # Point at the .npy file just saved. This config previously reused
        # the .pkl path, so npy loading was never exercised.
        loss_cfg = dict(
            type='DiceLoss',
            reduction='none',
            class_weight=npy_path,
            loss_weight=1.0,
            ignore_index=1,
            loss_name='loss_dice')
        dice_loss = build_loss(loss_cfg)
        dice_loss(logits, labels, ignore_index=None)

    # test dice loss on a two-class input with custom smooth / exponent
    loss_cfg = dict(
        type='DiceLoss',
        smooth=2,
        exponent=3,
        reduction='sum',
        loss_weight=1.0,
        ignore_index=0,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    logits = torch.rand(8, 2, 4, 4)
    labels = (torch.rand(8, 4, 4) * 2).long()
    dice_loss(logits, labels)

    # Test that dice loss has the name `loss_dice`. The instance above was
    # built from the exact config this check used to rebuild.
    assert dice_loss.loss_name == 'loss_dice'