Ejemplo n.º 1
0
def test_invariance_pooled():
    """Assert that the avg-pooled, type-0-only model output does not change under rotation.

    A model is built with ``pooling='avg'`` and ``return_type=0``. A rotation
    matrix is drawn from three random values via ``rot``, and the two outputs
    returned by ``_get_outputs`` are required to match within ``TOL``.
    NOTE(review): presumably ``_get_outputs`` evaluates the model on the original
    and the rotated input respectively; confirm against its definition.
    """
    # The model is created before sampling angles. Keep this order, because
    # building the model may draw from the same global torch RNG.
    pooled_model = _get_model(pooling='avg', return_type=0)
    angles = torch.rand(3)
    rotation = rot(*angles)
    # Move the rotation to the GPU whenever one exists. This presumably matches
    # the device the helpers place the model/inputs on; TODO confirm.
    rotation = rotation.cuda() if torch.cuda.is_available() else rotation

    reference, rotated = _get_outputs(pooled_model, rotation)

    # Scalar (type-0) features must be unaffected by the rotation.
    assert torch.allclose(rotated, reference, atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(reference, rotated)}'
Ejemplo n.º 2
0
def test_equivariance():
    """Assert rotation behaviour of the default model's '0' and '1' output channels.

    The '0' entries of the two outputs from ``_get_outputs`` must match within
    ``TOL``. The second output's '1' entry must match the first output's '1'
    entry right-multiplied by the rotation matrix.
    NOTE(review): presumably the second output is the model's response to the
    rotated input; confirm against ``_get_outputs``.
    """
    # Build the model first. Sampling the rotation afterwards keeps the
    # global RNG consumption order unchanged.
    net = _get_model()
    rotation = rot(*torch.rand(3))
    rotation = rotation.cuda() if torch.cuda.is_available() else rotation

    base, turned = _get_outputs(net, rotation)

    # Degree-0 channels: expected identical before and after rotation.
    assert torch.allclose(turned['0'], base['0'], atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(base["0"], turned["0"])}'

    # Degree-1 channels: expected to transform as `features @ R`.
    assert torch.allclose(turned['1'], base['1'] @ rotation, atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(base["1"] @ rotation, turned["1"])}'