def test_vgg_gpu():
    """ContextualLoss with ``use_vgg=True`` returns a 0-d tensor on CUDA.

    The input keeps ``test_shape``'s batch and spatial sizes but forces the
    channel dimension to 3 (presumably because the VGG feature extractor
    expects RGB input — confirm against ContextualLoss). The test is skipped,
    rather than failed, when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        import pytest

        # Skip instead of erroring on CPU-only machines.
        pytest.skip('CUDA device not available')
    device = torch.device('cuda:0')
    # Allocate directly on the GPU instead of creating on CPU and copying.
    prediction = torch.rand(test_shape[0], 3, *test_shape[2:], device=device)
    f = ContextualLoss(use_vgg=True).to(device)
    loss = f(prediction, prediction)
    assert loss.shape == torch.Size([])


def test_module():
    """A default-configured ContextualLoss reduces to a scalar (0-d) tensor."""
    sample = torch.rand(*test_shape)
    criterion = ContextualLoss()
    # The input is compared against itself; only the output shape is checked.
    result = criterion(sample, sample)
    assert result.shape == torch.Size(())


def test_module_gpu():
    """A default ContextualLoss moved to CUDA returns a 0-d tensor.

    The test is skipped, rather than failed, when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        import pytest

        # Skip instead of erroring on CPU-only machines.
        pytest.skip('CUDA device not available')
    device = torch.device('cuda:0')
    # Allocate directly on the GPU instead of creating on CPU and copying.
    prediction = torch.rand(*test_shape, device=device)
    f = ContextualLoss().to(device)
    loss = f(prediction, prediction)
    assert loss.shape == torch.Size([])