def test_vgg_gpu():
    """Check that ContextualLoss(use_vgg=True) returns a 0-d tensor on ``cuda:0``.

    Skipped when CUDA is unavailable: moving tensors/modules to ``cuda:0`` on a
    CPU-only machine raises an environment error rather than exercising the loss.
    """
    import pytest  # local import keeps this fix self-contained within the test

    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")

    # Channel dim is forced to 3 (batch and spatial dims come from test_shape),
    # presumably because the VGG feature extractor expects RGB input — confirm.
    prediction = torch.rand(test_shape[0], 3, *test_shape[2:]).to('cuda:0')
    f = ContextualLoss(use_vgg=True).to('cuda:0')
    # Comparing the input against itself makes the check independent of data values.
    loss = f(prediction, prediction)
    assert loss.shape == torch.Size([])
def test_module():
    """Check that ContextualLoss on CPU reduces its inputs to a 0-d tensor."""
    sample = torch.rand(*test_shape)
    criterion = ContextualLoss()
    scalar_shape = torch.Size([])
    # Feeding the same tensor as both arguments keeps the check data-independent.
    assert criterion(sample, sample).shape == scalar_shape
def test_module_gpu():
    """Check that ContextualLoss on ``cuda:0`` returns a 0-d tensor.

    Skipped when CUDA is unavailable: moving tensors/modules to ``cuda:0`` on a
    CPU-only machine raises an environment error rather than exercising the loss.
    """
    import pytest  # local import keeps this fix self-contained within the test

    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")

    prediction = torch.rand(*test_shape).to('cuda:0')
    f = ContextualLoss().to('cuda:0')
    # Comparing the input against itself makes the check independent of data values.
    loss = f(prediction, prediction)
    assert loss.shape == torch.Size([])