def test_invariance_pooled():
    """Check that pooled type-0 output does not change under a random transform.

    Builds a model with ``pooling='avg'`` and ``return_type=0``, obtains a
    pair of outputs from ``_get_outputs(model, R)`` and requires them to agree
    within ``TOL``.

    NOTE(review): ``rot(*torch.rand(3))`` presumably builds a rotation matrix
    from three random angles, and ``_get_outputs`` presumably evaluates the
    model on original and rotated inputs -- confirm against those helpers.
    """
    model = _get_model(pooling='avg', return_type=0)

    # Draw the transform after the model is built so the RNG is consumed in
    # the same order as before; move it to the GPU only when one is present.
    R = rot(*torch.rand(3))
    R = R.cuda() if torch.cuda.is_available() else R

    out1, out2 = _get_outputs(model, R)

    # The failure message is only formatted (and get_max_diff only called)
    # when the comparison fails.
    assert torch.allclose(out2, out1, atol=TOL), (
        f'type-0 features should be invariant {get_max_diff(out1, out2)}'
    )
def test_equivariance():
    """Check type-0 invariance and type-1 equivariance of the default model.

    Obtains a pair of outputs from ``_get_outputs(model, R)`` and requires
    that the ``'0'`` entries agree within ``TOL`` and that the ``'1'`` entry
    of the second output equals the ``'1'`` entry of the first output
    right-multiplied by ``R``, also within ``TOL``.

    NOTE(review): ``rot(*torch.rand(3))`` presumably builds a rotation matrix
    from three random angles, and ``_get_outputs`` presumably evaluates the
    model on original and rotated inputs -- confirm against those helpers.
    """
    model = _get_model()

    # Draw the transform after the model is built so the RNG is consumed in
    # the same order as before; move it to the GPU only when one is present.
    R = rot(*torch.rand(3))
    R = R.cuda() if torch.cuda.is_available() else R

    out1, out2 = _get_outputs(model, R)

    # Type-0 entries must match directly.
    assert torch.allclose(out2['0'], out1['0'], atol=TOL), (
        f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
    )

    # Type-1 entries must match once the first output is multiplied by R.
    assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), (
        f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
    )