Ejemplo n.º 1
0
def test_global_ic_inversion(ctx, alanine_ics):
    """Round-trip alanine positions through the global IC transform.

    Cartesian -> internal coordinates -> Cartesian must reproduce the input
    positions, and the forward and inverse log-det-Jacobian terms must be
    negatives of each other. Single precision gets a looser tolerance.
    """
    is_single_precision = ctx["dtype"] is torch.float32
    tol = 1e-3 if is_single_precision else 1e-5
    _, zmat, _, xyz = alanine_ics

    transform = GlobalInternalCoordinateTransformation(zmat).to(**ctx)
    x = torch.tensor(xyz, **ctx)

    # The forward output ends with the log-det term; everything before it
    # is fed back into the inverse pass.
    *internal, logdet_fwd = transform.forward(x)
    x_back, logdet_inv = transform.forward(*internal, inverse=True)

    assert torch.allclose(x, x_back, atol=tol)
    assert torch.allclose(logdet_fwd, -logdet_inv, atol=tol)
Ejemplo n.º 2
0
def test_builder_bond_constraints(ala2, ctx):
    """Build a generator in which two bonds are held at fixed lengths.

    Before the constraints are merged, the BONDS channel (current and prior)
    has 19 entries. After ``add_merge_constraints`` it has 21. The finished
    generator must sample a (2, 66) batch and support energy and KL
    evaluation.
    """
    pytest.importorskip("nflows")

    transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shapes = ShapeDictionary.from_coordinate_transform(
        transform,
        dim_augmented=0,
        n_constraints=2,
        remove_origin_and_rotation=True,
    )
    builder = BoltzmannGeneratorBuilder(shapes, target=ala2.system.energy_model, **ctx)

    # Only the unconstrained bonds are modelled at this stage.
    for dims in (builder.current_dims, builder.prior_dims):
        assert dims[BONDS] == (19,)

    builder.add_condition(BONDS, on=(ANGLES, TORSIONS))
    builder.add_map_to_ic_domains()

    # Bonds 0 and 1 are constrained to length 0.1 and merged back into BONDS.
    builder.add_merge_constraints([0, 1], [0.1, 0.1])
    assert builder.current_dims[BONDS] == (21,)

    builder.add_map_to_cartesian(transform)
    generator = builder.build_generator()

    # Exercise sampling, energy evaluation and the KL divergence.
    samples = generator.sample(2)
    assert samples.shape == (2, 66)
    generator.energy(samples)
    generator.kldiv(10)
Ejemplo n.º 3
0
def test_constrain_chirality(ala2, ctx):
    """Samples from a chirality-constrained generator keep chiral torsions in [0.5, 1].

    The chiral torsions are identified by ``bgmol.is_chiral_torsion``. After
    sampling, the samples are mapped back to internal coordinates and the
    torsion values at those indices are range-checked.
    """
    bgmol = pytest.importorskip("bgmol")
    topology = ala2.system.mdtraj_topology
    zmat, _fixed = bgmol.ZMatrixFactory(topology).build_naive()

    transform = GlobalInternalCoordinateTransformation(zmat)
    builder = BoltzmannGeneratorBuilder(
        ShapeDictionary.from_coordinate_transform(transform),
        target=ala2.system.energy_model,
        **ctx,
    )
    chiral = bgmol.is_chiral_torsion(transform.torsion_indices, topology)
    builder.add_constrain_chirality(chiral)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(transform)
    generator = builder.build_generator()

    samples = generator.sample(20)
    # Third output of the forward transform is the torsion block.
    _, _, torsions, *_ = transform.forward(samples)
    chiral_values = torsions[:, chiral]
    assert torch.all(chiral_values >= 0.5)
    assert torch.all(chiral_values <= 1.0)
Ejemplo n.º 4
0
def test_global_ic_properties(ctx):
    """Check output shapes and index bookkeeping of a 5-atom global IC transform.

    The z-matrix describes a linear chain 0-1-2-3-4. The first three rows
    define no torsion, so only rows 3 onward appear in ``ic.z_matrix``.
    Bonds, angles and torsions are read from columns [:2], [:3] and [:]
    of the rows that define them.
    """
    z_matrix = np.array([
        [0, -1, -1, -1],
        [1, 0, -1, -1],
        [2, 1, 0, -1],
        [3, 2, 1, 0],
        [4, 3, 2, 1],
    ])
    n_cartesian = 15  # 5 atoms x 3 coordinates
    n_batch = 10

    transform = GlobalInternalCoordinateTransformation(z_matrix).to(**ctx)
    bonds, angles, torsions, x0, orientation, *_ = transform.forward(
        torch.randn(n_batch, n_cartesian, **ctx)
    )

    assert (z_matrix[3:] == transform.z_matrix).all()
    assert len(transform.fixed_atoms) == 0

    expected_shapes = (
        (bonds, (n_batch, transform.dim_bonds)),
        (angles, (n_batch, transform.dim_angles)),
        (torsions, (n_batch, transform.dim_torsions)),
        (x0, (n_batch, 1, 3)),
        (orientation, (n_batch, 3)),
    )
    for tensor, shape in expected_shapes:
        assert tensor.shape == shape

    assert transform.dim_fixed == 0
    assert transform.normalize_angles
    assert (transform.bond_indices == z_matrix[1:, :2]).all()
    assert (transform.angle_indices == z_matrix[2:, :3]).all()
    assert (transform.torsion_indices == z_matrix[3:, :]).all()
Ejemplo n.º 5
0
def test_builder_multiple_crd(ala2, ctx):
    """Stack a coarse-grained IC transform underneath an all-atom one.

    The all-atom z-matrix is built with ``cartesian=[6, 8, 10, 14, 16]``.
    A separate global IC transform over a fake 5-atom topology, with its
    channels renamed to CG_BONDS / CG_ANGLES / CG_TORSIONS, is mapped to
    Cartesian space with its output written to the FIXED channel. The full
    generator must then sample and support energy and KL evaluation.
    """
    bgmol = pytest.importorskip("bgmol")
    pytest.importorskip("nflows")

    # All-atom transform.
    topology = ala2.system.mdtraj_topology
    z_matrix, fixed = bgmol.ZMatrixFactory(topology, cartesian=[6, 8, 10, 14, 16]).build_naive()
    crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)

    # Coarse-grained transform on a fake 5-atom topology, with renamed channels.
    cg_top, _ = bgmol.build_fake_topology(5)
    cg_z_matrix, _ = bgmol.ZMatrixFactory(cg_top).build_naive()
    cg_crd_transform = GlobalInternalCoordinateTransformation(cg_z_matrix)
    cg_shape_info = ShapeDictionary.from_coordinate_transform(cg_crd_transform)
    cg_bonds, cg_angles, cg_torsions = (
        cg_shape_info.replace(field, new_name)
        for field, new_name in (
            (BONDS, "CG_BONDS"),
            (ANGLES, "CG_ANGLES"),
            (TORSIONS, "CG_TORSIONS"),
        )
    )
    shape_info.update(cg_shape_info)
    # FIXED is removed from the shape dictionary; it is produced below as
    # the output of the CG transform (presumably that is why it is deleted).
    del shape_info[FIXED]

    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    cg_channels = dict(bonds=cg_bonds, angles=cg_angles, torsions=cg_torsions)

    # Coarse-grained flow layers, then map the CG ICs to Cartesian -> FIXED.
    for _ in range(2):
        builder.add_condition(cg_torsions, on=(cg_angles, cg_bonds))
        builder.add_condition((cg_angles, cg_bonds), on=cg_torsions)
    marginals = InternalCoordinateMarginals(builder.current_dims, builder.ctx, **cg_channels)
    builder.add_map_to_ic_domains(marginals)
    builder.add_map_to_cartesian(cg_crd_transform, out=FIXED, **cg_channels)
    builder.transformer_type[FIXED] = bg.AffineTransformer

    # All-atom flow layers conditioned on the coarse-grained FIXED positions.
    for _ in range(2):
        builder.add_condition(TORSIONS, on=FIXED)
        builder.add_condition(FIXED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, FIXED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, FIXED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # Exercise sampling, energy evaluation and the KL divergence.
    samples = generator.sample(2)
    generator.energy(samples)
    generator.kldiv(10)
Ejemplo n.º 6
0
def test_builder_augmentation_and_global(ala2, ctx):
    """Build an augmented generator on top of a global IC transform.

    Ten augmented dimensions are coupled with the torsions. Bonds and angles
    are conditioned on each other and on the torsion/augmented channels.
    The test checks that sampling returns a sequence of output tensors (it
    is unpacked into ``generator.energy``), each with the requested batch
    size, and that energy and KL-divergence evaluation run.
    """
    pytest.importorskip("nflows")

    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform, dim_augmented=10)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    for _ in range(4):
        builder.add_condition(TORSIONS, on=AUGMENTED)
        builder.add_condition(AUGMENTED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, AUGMENTED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, AUGMENTED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    n_samples = 2
    samples = generator.sample(n_samples)
    # ``len(samples) == 2`` alone is ambiguous here: the number of output
    # tensors and the requested batch size are both 2. Pin the batch
    # dimension of every returned tensor explicitly as well.
    assert len(samples) == 2  # presumably Cartesian positions + augmented variables
    assert all(s.shape[0] == n_samples for s in samples)
    generator.energy(*samples)
    generator.kldiv(10)
Ejemplo n.º 7
0
def test_global_ic_transform(device, dtype):
    """Round-trip and independence checks for GlobalInternalCoordinateTransformation.

    Uses a 5-atom linear-chain z-matrix and runs N_REPETITIONS times for
    both ``normalize_angles=True`` and ``normalize_angles=False``:

    1. ICs -> Cartesian -> ICs reconstructs bonds, angles, torsions, x0 and
       orientation. The forward and inverse log-det terms must cancel.
    2. Cartesian -> ICs -> Cartesian reconstructs the positions, with the
       same log-det check.
    3. Replacing one IC group with noise and round-tripping leaves the
       reconstruction of every *other* group unchanged.
    """
    atol, rtol = TOLERANCES[device][dtype]
    torch.manual_seed(1)

    # NOTE(review): for float32/float64 the TOLERANCES lookup above is
    # overwritten here, so it only takes effect for other dtypes.
    if dtype == torch.float32:
        atol = 1e-3
        rtol = 1e-3
    elif dtype == torch.float64:
        atol = 1e-5
        rtol = 1e-4

    N_SAMPLES = 1
    N_BONDS = 4
    N_ANGLES = 3
    N_TORSIONS = 2
    N_PARTICLES = 5

    # Linear chain 0-1-2-3-4; -1 marks "no reference atom".
    _Z_MATRIX = np.array([[0, -1, -1, -1], [1, 0, -1, -1], [2, 1, 0, -1],
                          [3, 2, 1, 0], [4, 3, 2, 1]])

    for _ in range(N_REPETITIONS):

        for normalize_angles in [True, False]:

            ic = GlobalInternalCoordinateTransformation(
                _Z_MATRIX, normalize_angles=normalize_angles)

            # Test ic -> xyz -> ic reconstruction
            # Bonds are exp(normal), hence strictly positive.
            bonds = torch.randn(N_SAMPLES, N_BONDS, device=device,
                                dtype=dtype).exp()
            angles = torch.rand(N_SAMPLES,
                                N_ANGLES,
                                device=device,
                                dtype=dtype)
            torsions = torch.rand(N_SAMPLES,
                                  N_TORSIONS,
                                  device=device,
                                  dtype=dtype)

            # Without normalization, angles are rescaled to [0, pi) and
            # torsions to [-pi, pi); otherwise they stay in [0, 1).
            if not normalize_angles:
                angles *= np.pi
                torsions = (2 * torsions - 1) * np.pi

            x0 = torch.randn(N_SAMPLES, 1, 3, device=device, dtype=dtype)

            # Orientation parameters: without normalization alpha and gamma
            # lie in [-pi, pi) and beta in [-1, 1); otherwise all in [0, 1).
            alpha = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            beta = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            gamma = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            if not normalize_angles:
                alpha = alpha * 2 * np.pi - np.pi
                beta = 2 * beta - 1
                gamma = gamma * 2 * np.pi - np.pi
            orientation = torch.stack([alpha, beta, gamma], dim=-1)

            x, dlogp_fwd = ic(bonds,
                              angles,
                              torsions,
                              x0,
                              orientation,
                              inverse=True)
            (
                bonds_recon,
                angles_recon,
                torsions_recon,
                x0_recon,
                orientation_recon,
                dlogp_inv,
            ) = ic(x)

            failure_message = f"normalize_angles={normalize_angles};"

            # check valid reconstructions
            for name, truth, recon in zip(
                ["bonds", "angles", "torsions", "x0", "orientation"],
                [bonds, angles, torsions, x0, orientation],
                [
                    bonds_recon,
                    angles_recon,
                    torsions_recon,
                    x0_recon,
                    orientation_recon,
                ],
            ):
                assert torch.allclose(truth, recon, atol=atol,
                                      rtol=rtol), (failure_message +
                                                   f"{name} != {name}_recon;")
            # NOTE(review): allclose checks |a - b| <= atol + rtol * |b|.
            # With b == 1 and rtol=1.0 this only requires
            # exp(dlogp_fwd + dlogp_inv) <= ~2.001, a very lenient check.
            assert torch.allclose(
                (dlogp_fwd + dlogp_inv).exp(),
                torch.ones_like(dlogp_fwd),
                atol=1e-3,
                rtol=1.0,
            ), failure_message

            # Test xyz -> ic -> xyz reconstruction
            x = torch.randn(N_SAMPLES,
                            N_PARTICLES * 3,
                            device=device,
                            dtype=dtype)

            *ics, dlogp_fwd = ic(x)
            x_recon, dlogp_inv = ic(*ics, inverse=True)

            assert torch.allclose(x, x_recon, atol=atol,
                                  rtol=rtol), failure_message
            # NOTE(review): same lenient rtol=1.0 check as above.
            assert torch.allclose(
                (dlogp_fwd + dlogp_inv).exp(),
                torch.ones_like(dlogp_fwd),
                atol=1e-3,
                rtol=1.0,
            ), failure_message

            # Test IC independence
            # Each torch call draws an "original" and a "noise" tensor at once.
            # NOTE(review): unlike the first section, angles and torsions here
            # stay in [0, 1) even when normalize_angles is False.
            bonds, bonds_noise = torch.randn(2,
                                             N_SAMPLES,
                                             N_BONDS,
                                             device=device,
                                             dtype=dtype).exp()
            angles, angles_noise = torch.rand(2,
                                              N_SAMPLES,
                                              N_ANGLES,
                                              device=device,
                                              dtype=dtype)
            torsions, torsions_noise = torch.rand(2,
                                                  N_SAMPLES,
                                                  N_TORSIONS,
                                                  device=device,
                                                  dtype=dtype)
            x0, x0_noise = torch.randn(2,
                                       N_SAMPLES,
                                       1,
                                       3,
                                       device=device,
                                       dtype=dtype)

            alpha = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            beta = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            gamma = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            if not normalize_angles:
                alpha = alpha * 2 * np.pi - np.pi
                beta = 2 * beta - 1
                gamma = gamma * 2 * np.pi - np.pi
            orientation = torch.stack([alpha, beta, gamma], dim=-1)

            alpha_noise = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            beta_noise = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            gamma_noise = torch.rand(N_SAMPLES, device=device, dtype=dtype)
            if not normalize_angles:
                alpha_noise = alpha_noise * 2 * np.pi - np.pi
                beta_noise = 2 * beta_noise - 1
                gamma_noise = gamma_noise * 2 * np.pi - np.pi
            orientation_noise = torch.stack(
                [alpha_noise, beta_noise, gamma_noise], dim=-1)

            names = ["bonds", "angles", "torsions", "x0", "orientation"]
            orig = [bonds, angles, torsions, x0, orientation]
            noise = [
                bonds_noise,
                angles_noise,
                torsions_noise,
                x0_noise,
                orientation_noise,
            ]

            # Swap in noise for group i only; all groups j != i must still
            # reconstruct to their original values after the round trip.
            # (name_noise is unused, and name_recon always equals names[j].)
            for i, name_noise in enumerate(names):
                noisy_ics = orig[:i] + [noise[i]] + orig[i + 1:]
                x, _ = ic(*noisy_ics, inverse=True)
                *noisy_ics_recon, _ = ic(x)
                for j, name_recon in enumerate(names):
                    if i != j:
                        assert torch.allclose(
                            orig[j], noisy_ics_recon[j], atol=atol,
                            rtol=rtol), (failure_message +
                                         f"{names[j]} != {name_recon}_recon")