示例#1
0
def test_shape_cutoffs(distances):
    """Every cutoff variant must return a tensor shaped like its input.

    A ``CosineCutoff``, ``MollifierCutoff`` and ``HardCutoff`` (all with
    default arguments) are each applied to ``distances``, and
    ``assert_equal_shape`` checks the result against ``distances.shape``.
    """
    cutoff_modules = (CosineCutoff(), MollifierCutoff(), HardCutoff())
    # One shared argument list, handed to every module in turn.
    inputs = [distances]
    expected_shape = list(distances.shape)

    for module in cutoff_modules:
        assert_equal_shape(module, inputs, expected_shape)
示例#2
0
def test_cutoff_mollifier():
    """Check ``MollifierCutoff`` built with an explicit radius of 2.3.

    The test verifies, in order:

    * the module's ``cutoff`` attribute equals the requested radius;
    * a tensor of zero distances maps to ones;
    * random distances in ``[0, 1)`` match the closed form
      ``exp(1 - 1 / (1 - (d / r)**2))``;
    * the same distances scaled by 3.8 match that closed form, except that
      entries at or beyond the radius are expected to be exactly zero.
    """
    radius = 2.3
    scale = 3.8
    mollifier = MollifierCutoff(cutoff=radius)

    def reference(d):
        # Unmasked closed-form mollifier value for distances ``d``.
        return torch.exp(1.0 - 1.0 / (1.0 - (d / radius)**2))

    assert abs(radius - mollifier.cutoff) < 1.0e-12

    zero_dist = torch.zeros((4, 1, 1))
    assert torch.allclose(
        torch.ones(4, 1, 1), mollifier(zero_dist), atol=0.0, rtol=1.0e-7
    )

    # Fixed seed so the random distances are reproducible.
    torch.manual_seed(42)
    inner = torch.rand((1, 3, 9), dtype=torch.float)
    assert torch.allclose(
        reference(inner), mollifier(inner), atol=0.0, rtol=1.0e-7
    )

    scaled = scale * inner
    actual = mollifier(scaled)
    # The closed form is not meaningful at or past the radius; the expected
    # cutoff there is zero, so those entries are replaced.
    expected = reference(scaled).masked_fill(scaled >= radius, 0.0)
    assert torch.allclose(expected, actual, atol=0.0, rtol=1.0e-7)
示例#3
0
def test_cutoff_mollifier_default():
    """Check ``MollifierCutoff`` built with its default radius.

    The test expects the default radius to be 5.0, and then verifies:

    * a tensor of zero distances maps to ones;
    * random distances in ``[0, 1)`` match the closed form
      ``exp(1 - 1 / (1 - (d / r)**2))``;
    * the same distances scaled by 6.0 match that closed form, except that
      entries at or beyond the radius are expected to be exactly zero.
    """
    radius = 5.0
    scale = 6.0
    mollifier = MollifierCutoff()

    def reference(d):
        # Unmasked closed-form mollifier value for distances ``d``.
        return torch.exp(1.0 - 1.0 / (1.0 - (d / radius)**2))

    assert abs(radius - mollifier.cutoff) < 1.0e-12

    zero_dist = torch.zeros((5, 2, 3))
    assert torch.allclose(
        torch.ones(5, 2, 3), mollifier(zero_dist), atol=0.0, rtol=1.0e-7
    )

    # Fixed seed so the random distances are reproducible.
    torch.manual_seed(42)
    inner = torch.rand((20, 1), dtype=torch.float)
    assert torch.allclose(
        reference(inner), mollifier(inner), atol=0.0, rtol=1.0e-7
    )

    scaled = scale * inner
    actual = mollifier(scaled)
    # The closed form is not meaningful at or past the radius; the expected
    # cutoff there is zero, so those entries are replaced.
    expected = reference(scaled).masked_fill(scaled >= radius, 0.0)
    assert torch.allclose(expected, actual, atol=0.0, rtol=1.0e-7)