def test_builder_bond_constraints(ala2, ctx):
    """Build and run a generator in which two bonds are constrained.

    The shape dictionary is created with ``n_constraints=2``, so before the
    constraints are merged the BONDS field has 19 entries. After
    ``add_merge_constraints`` it has 21. The assembled generator is then
    sampled, its energy evaluated and its KL divergence computed as a smoke
    test.
    """
    pytest.importorskip("nflows")
    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(
        crd_transform,
        dim_augmented=0,
        n_constraints=2,
        remove_origin_and_rotation=True,
    )
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    constrained_bond_indices = [0, 1]
    constrained_bond_lengths = [0.1, 0.1]
    # Before merging, the constrained bonds are excluded from the BONDS field.
    assert builder.current_dims[BONDS] == (19, )
    assert builder.prior_dims[BONDS] == (19, )
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS))
    builder.add_map_to_ic_domains()
    builder.add_merge_constraints(constrained_bond_indices, constrained_bond_lengths)
    # Merging re-adds one entry per constrained bond (19 + 2).
    assert builder.current_dims[BONDS] == (21, )
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()
    # sample, evaluate the generator energy, and compute the KL divergence
    samples = generator.sample(2)
    assert samples.shape == (2, 66)
    generator.energy(samples)
    generator.kldiv(10)
def test_constrain_chirality(ala2, ctx):
    """Check that ``add_constrain_chirality`` keeps chiral torsions in [0.5, 1.0].

    Samples drawn from the generator are mapped back to internal coordinates.
    The torsions selected by ``bgmol.is_chiral_torsion`` must then all lie
    within that interval.
    """
    bgmol = pytest.importorskip("bgmol")
    top = ala2.system.mdtraj_topology
    zmatrix, _ = bgmol.ZMatrixFactory(top).build_naive()
    crd_transform = GlobalInternalCoordinateTransformation(zmatrix)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    chiral_torsions = bgmol.is_chiral_torsion(crd_transform.torsion_indices, top)
    builder.add_constrain_chirality(chiral_torsions)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()
    samples = generator.sample(20)
    # Only the torsions are inspected; the other outputs of forward() are discarded.
    _, _, torsions, *_ = crd_transform.forward(samples)
    chiral_values = torsions[:, chiral_torsions]
    assert torch.all(chiral_values >= 0.5)
    assert torch.all(chiral_values <= 1.0)
def test_builder_multiple_crd(ala2, ctx):
    """Build a generator that chains a coarse-grained and an all-atom transform.

    The first stage samples coarse-grained internal coordinates (fields renamed
    to CG_BONDS / CG_ANGLES / CG_TORSIONS) and maps them to Cartesian space
    with ``out=FIXED``. That output replaces the FIXED entry removed from the
    all-atom shape dictionary. The second stage conditions the all-atom
    internal coordinates on FIXED and maps everything to Cartesian space.
    """
    bgmol = pytest.importorskip("bgmol")
    pytest.importorskip("nflows")

    # all-atom transform; `cartesian=` selects the atoms returned as `fixed`
    z_matrix, fixed = bgmol.ZMatrixFactory(ala2.system.mdtraj_topology, cartesian=[6, 8, 10, 14, 16]).build_naive()
    crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)

    # coarse-grained transform on a fake 5-particle topology; its IC fields are
    # renamed so they do not clash with the all-atom BONDS/ANGLES/TORSIONS
    cg_top, _ = bgmol.build_fake_topology(5)
    cg_z_matrix, _ = bgmol.ZMatrixFactory(cg_top).build_naive()
    cg_crd_transform = GlobalInternalCoordinateTransformation(cg_z_matrix)
    cg_shape_info = ShapeDictionary.from_coordinate_transform(cg_crd_transform)
    CG_BONDS = cg_shape_info.replace(BONDS, "CG_BONDS")
    CG_ANGLES = cg_shape_info.replace(ANGLES, "CG_ANGLES")
    CG_TORSIONS = cg_shape_info.replace(TORSIONS, "CG_TORSIONS")
    shape_info.update(cg_shape_info)
    # FIXED is not sampled from the prior; the CG stage below produces it
    del shape_info[FIXED]

    # stage 1: coarse-grained flow, mapped to Cartesian space as FIXED
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    for _ in range(2):
        builder.add_condition(CG_TORSIONS, on=(CG_ANGLES, CG_BONDS))
        builder.add_condition((CG_ANGLES, CG_BONDS), on=CG_TORSIONS)
    marginals = InternalCoordinateMarginals(builder.current_dims, builder.ctx, bonds=CG_BONDS, angles=CG_ANGLES, torsions=CG_TORSIONS)
    builder.add_map_to_ic_domains(marginals)
    builder.add_map_to_cartesian(cg_crd_transform, bonds=CG_BONDS, angles=CG_ANGLES, torsions=CG_TORSIONS, out=FIXED)

    # stage 2: all-atom flow conditioned on the CG output
    builder.transformer_type[FIXED] = bg.AffineTransformer
    for _ in range(2):
        builder.add_condition(TORSIONS, on=FIXED)
        builder.add_condition(FIXED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, FIXED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, FIXED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()
    # sample, evaluate the generator energy, and compute the KL divergence
    samples = generator.sample(2)
    generator.energy(samples)
    generator.kldiv(10)
def test_builder_augmentation_and_global(ala2, ctx):
    """Build a generator with a global IC transform and 10 augmented dimensions.

    With augmentation, ``generator.sample`` yields more than one tensor. The
    tensors are unpacked into ``generator.energy``.
    """
    pytest.importorskip("nflows")

    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform, dim_augmented=10)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    for _ in range(4):
        builder.add_condition(TORSIONS, on=AUGMENTED)
        builder.add_condition(AUGMENTED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, AUGMENTED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, AUGMENTED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()
    # The batch size deliberately differs from the expected number of outputs (2).
    # Otherwise the output-count check could not be told apart from a batch-size check.
    n_samples = 3
    samples = generator.sample(n_samples)
    assert len(samples) == 2
    assert all(len(output) == n_samples for output in samples)
    generator.energy(*samples)
    generator.kldiv(10)
def test_builder_api(ala2, ctx):
    """Smoke-test the builder on a ``MixedCoordinateTransformation``.

    The transform is built from the reference coordinates, the z-matrix and
    the rigid block of fixed atoms. The resulting generator is then sampled,
    its energy evaluated and its KL divergence computed.
    """
    pytest.importorskip("nflows")

    def couple(first, second, repeats):
        # condition `first` on `second` and vice versa, `repeats` times in a row
        for _ in range(repeats):
            builder.add_condition(first, on=second)
            builder.add_condition(second, on=first)

    zmat = ala2.system.z_matrix
    rigid_atoms = ala2.system.rigid_block
    reference = torch.tensor(ala2.xyz, **ctx)
    transform = MixedCoordinateTransformation(reference, zmat, rigid_atoms)
    builder = BoltzmannGeneratorBuilder(
        ShapeDictionary.from_coordinate_transform(transform),
        target=ala2.system.energy_model,
        **ctx,
    )
    couple(TORSIONS, FIXED, 4)
    couple(BONDS, ANGLES, 2)
    builder.add_condition(ANGLES, on=(TORSIONS, FIXED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, FIXED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(transform)
    flow_generator = builder.build_generator()
    # sample, evaluate the generator energy, and compute the KL divergence
    drawn = flow_generator.sample(2)
    flow_generator.energy(drawn)
    flow_generator.kldiv(10)