Exemplo n.º 1
0
def test_builder_add_layer_and_param_groups(ctx):
    """Layers added with ``param_groups`` register their parameters with the builder.

    A ``CDFTransform`` is added on BONDS only (group ``"group1"``), then a
    ``CouplingFlow`` is added without ``what`` (groups ``"group1"`` and
    ``"group2"``). Afterwards ``"group1"`` must hold every parameter of the
    generator and ``"group2"`` exactly the parameters of
    ``generator._flow._blocks[1]``.
    """
    shape_info = ShapeDictionary()
    shape_info[BONDS] = (10, )
    shape_info[ANGLES] = (20, )
    builder = BoltzmannGeneratorBuilder(shape_info, **ctx)
    # transform some fields
    builder.add_layer(
        CDFTransform(
            # ``np.inf`` rather than the ``np.infty`` alias, which NumPy 2.0 removed
            TruncatedNormalDistribution(torch.zeros(10, **ctx), lower_bound=-torch.tensor(np.inf)),
        ),
        what=[BONDS],
        inverse=True,
        param_groups=("group1", ),
    )
    # transform all fields
    builder.add_layer(
        CouplingFlow(
            AffineTransformer(
                DenseNet([10, 20]), DenseNet([10, 20])
            )
        ),
        param_groups=("group1", "group2"),
    )
    builder.targets[BONDS] = NormalDistribution(10, torch.zeros(10, **ctx))
    builder.targets[ANGLES] = NormalDistribution(20, torch.zeros(20, **ctx))
    generator = builder.build_generator().to(**ctx)
    assert builder.param_groups["group1"] == list(generator.parameters())
    assert builder.param_groups["group2"] == list(generator._flow._blocks[1].parameters())
    # smoke test: sampling and the KL divergence run end to end
    generator.sample(10)
    generator.kldiv(10)
Exemplo n.º 2
0
def test_builder_bond_constraints(ala2, ctx):
    """Merging constrained bonds back in restores the full bond field.

    With ``n_constraints=2`` the builder starts from 19 bond dimensions (both
    current and prior). After ``add_merge_constraints`` the bond field has 21
    entries, and the resulting generator produces samples of shape (2, 66).
    """
    pytest.importorskip("nflows")
    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(
        crd_transform, dim_augmented=0, n_constraints=2, remove_origin_and_rotation=True,
    )
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)

    # the two constrained bonds are excluded from both the current and the prior dims
    for dims in (builder.current_dims, builder.prior_dims):
        assert dims[BONDS] == (19, )

    builder.add_condition(BONDS, on=(ANGLES, TORSIONS))
    builder.add_map_to_ic_domains()
    fixed_bond_indices = [0, 1]
    fixed_bond_lengths = [0.1, 0.1]
    builder.add_merge_constraints(fixed_bond_indices, fixed_bond_lengths)
    assert builder.current_dims[BONDS] == (21, )
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    xyz = generator.sample(2)
    assert xyz.shape == (2, 66)
    generator.energy(xyz)
    generator.kldiv(10)
Exemplo n.º 3
0
def test_icmarginals_inform_api(tmpdir, ctx, with_data):
    """API test.

    Informs ``InternalCoordinateMarginals`` either from the Ala2 test dataset
    (``with_data``) or from the system's force field.
    """
    bgmol = pytest.importorskip("bgmol")
    dataset = bgmol.datasets.Ala2Implicit1000Test(root=tmpdir, download=True, read=True)
    coordinate_transform = GlobalInternalCoordinateTransformation(
        bgmol.systems.ala2.DEFAULT_GLOBAL_Z_MATRIX
    )

    # one bond dimension fewer per constraint in the OpenMM system
    n_free_bonds = coordinate_transform.dim_bonds - dataset.system.system.getNumConstraints()
    current_dims = ShapeDictionary()
    current_dims[BONDS] = (n_free_bonds, )
    current_dims[ANGLES] = (coordinate_transform.dim_angles, )
    marginals = InternalCoordinateMarginals(current_dims, ctx)

    if not with_data:
        marginals.inform_with_force_field(dataset.system.system, coordinate_transform, 1000.)
        return

    constrained_indices, _ = bgmol.bond_constraints(dataset.system.system, coordinate_transform)
    xyz = torch.tensor(dataset.xyz, **ctx)
    marginals.inform_with_data(xyz, coordinate_transform, constrained_bond_indices=constrained_indices)
Exemplo n.º 4
0
def test_builder_multiple_crd(ala2, ctx):
    """Builder API with two chained coordinate transforms.

    A coarse-grained system built from ``bgmol.build_fake_topology(5)`` is
    sampled in its own internal coordinates (renamed to CG_* fields) and mapped
    to Cartesian coordinates written into FIXED. An all-atom relative
    internal-coordinate transform then consumes FIXED together with the
    all-atom BONDS/ANGLES/TORSIONS.
    """
    bgmol = pytest.importorskip("bgmol")
    pytest.importorskip("nflows")

    # all-atom trafo; atoms 6, 8, 10, 14, 16 are requested as Cartesian
    z_matrix, fixed = bgmol.ZMatrixFactory(ala2.system.mdtraj_topology, cartesian=[6, 8, 10, 14, 16]).build_naive()
    crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)

    # cg trafo, its fields renamed so they do not collide with the all-atom ones
    cg_top, _ = bgmol.build_fake_topology(5)
    cg_z_matrix, _ = bgmol.ZMatrixFactory(cg_top).build_naive()
    cg_crd_transform = GlobalInternalCoordinateTransformation(cg_z_matrix)
    cg_shape_info = ShapeDictionary.from_coordinate_transform(cg_crd_transform)
    cg_bonds = cg_shape_info.replace(BONDS, "CG_BONDS")
    cg_angles = cg_shape_info.replace(ANGLES, "CG_ANGLES")
    cg_torsions = cg_shape_info.replace(TORSIONS, "CG_TORSIONS")
    cg_fields = dict(bonds=cg_bonds, angles=cg_angles, torsions=cg_torsions)
    shape_info.update(cg_shape_info)
    # FIXED is not a prior field: it is produced by the CG -> Cartesian map below
    del shape_info[FIXED]

    # factory
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    for _ in range(2):
        builder.add_condition(cg_torsions, on=(cg_angles, cg_bonds))
        builder.add_condition((cg_angles, cg_bonds), on=cg_torsions)
    marginals = InternalCoordinateMarginals(builder.current_dims, builder.ctx, **cg_fields)
    builder.add_map_to_ic_domains(marginals)
    builder.add_map_to_cartesian(cg_crd_transform, **cg_fields, out=FIXED)
    builder.transformer_type[FIXED] = bg.AffineTransformer

    all_atom_couplings = (
        2 * [(TORSIONS, FIXED), (FIXED, TORSIONS)]
        + 2 * [(BONDS, ANGLES), (ANGLES, BONDS)]
        + [(ANGLES, (TORSIONS, FIXED)), (BONDS, (ANGLES, TORSIONS, FIXED))]
    )
    for transformed, condition in all_atom_couplings:
        builder.add_condition(transformed, on=condition)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    xyz = generator.sample(2)
    generator.energy(xyz)
    generator.kldiv(10)
Exemplo n.º 5
0
def test_transformers(crd_trafo, transformer_type):
    """A transformer on BONDS conditioned on FIXED keeps the BONDS output shape."""
    pytest.importorskip("nflows")

    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_fixed = shape_info[FIXED][0]
    n_bonds = shape_info[BONDS][0]

    conditioners = make_conditioners(transformer_type, (BONDS, ), (FIXED, ), shape_info)
    transformer = make_transformer(transformer_type, (BONDS, ), shape_info, conditioners=conditioners)

    batch_condition = torch.zeros(2, n_fixed)
    batch_bonds = torch.zeros(2, n_bonds)
    result = transformer.forward(batch_condition, batch_bonds)
    assert result[0].shape == (2, n_bonds)
Exemplo n.º 6
0
def test_constrain_chirality(ala2, ctx):
    """After ``add_constrain_chirality`` the chiral torsions of samples lie in [0.5, 1.0]."""
    bgmol = pytest.importorskip("bgmol")
    topology = ala2.system.mdtraj_topology
    zmatrix, _ = bgmol.ZMatrixFactory(topology).build_naive()
    crd_transform = GlobalInternalCoordinateTransformation(zmatrix)
    builder = BoltzmannGeneratorBuilder(
        ShapeDictionary.from_coordinate_transform(crd_transform),
        target=ala2.system.energy_model,
        **ctx,
    )
    chiral_torsions = bgmol.is_chiral_torsion(crd_transform.torsion_indices, topology)
    builder.add_constrain_chirality(chiral_torsions)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    xyz = generator.sample(20)
    _, _, torsions, *_ = crd_transform.forward(xyz)
    chiral_values = torsions[:, chiral_torsions]
    assert torch.all(chiral_values >= 0.5)
    assert torch.all(chiral_values <= 1.0)
Exemplo n.º 7
0
def test_conditioner_factory_spline(crd_trafo):
    """The spline params net's output size depends on the periodicity of its targets.

    Every transformed dimension gets ``3 * 8`` outputs; non-periodic ones
    (BONDS) get one extra output each, periodic ones (TORSIONS) do not.
    """
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_bonds = shape_info[BONDS][0]
    n_torsions = shape_info[TORSIONS][0]
    # NOTE(review): presumably 8 bins x 3 spline params per bin -- confirm
    # against the make_conditioners defaults
    outputs_per_dim = 3 * 8

    def final_bias_shape(what, on):
        conditioners = make_conditioners(ConditionalSplineTransformer, what, on, shape_info)
        return conditioners["params_net"]._layers[-1].bias.shape

    # non-periodic
    assert final_bias_shape((BONDS, ), (ANGLES, )) == ((outputs_per_dim + 1) * n_bonds, )
    # periodic
    assert final_bias_shape((TORSIONS, ), (ANGLES, )) == (outputs_per_dim * n_torsions, )
    # mixed
    assert final_bias_shape((BONDS, TORSIONS), (ANGLES, FIXED)) == (
        outputs_per_dim * (n_bonds + n_torsions) + n_bonds, )
Exemplo n.º 8
0
def test_builder_augmentation_and_global(ala2, ctx):
    """Builder API with 10 augmented dimensions and a global IC transform."""
    pytest.importorskip("nflows")

    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform, dim_augmented=10)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)

    couplings = (
        4 * [(TORSIONS, AUGMENTED), (AUGMENTED, TORSIONS)]
        + 2 * [(BONDS, ANGLES), (ANGLES, BONDS)]
        + [(ANGLES, (TORSIONS, AUGMENTED)), (BONDS, (ANGLES, TORSIONS, AUGMENTED))]
    )
    for transformed, condition in couplings:
        builder.add_condition(transformed, on=condition)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward; with augmentation, sample() is expected to
    # return two parts, which are passed to energy() as separate arguments
    sampled = generator.sample(2)
    assert len(sampled) == 2
    generator.energy(*sampled)
    generator.kldiv(10)
Exemplo n.º 9
0
def test_conditioner_factory_input_dim(transformer_type, crd_trafo):
    """Conditioner input widths match their inputs, and torsion inputs are periodic.

    Conditioning on (ANGLES, TORSIONS) gives a first-layer width of
    ``n_angles + 2 * n_torsions``. NOTE(review): this is presumably a two-feature
    embedding per torsion; confirm.
    """
    torch.manual_seed(10981)

    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_fixed = shape_info[FIXED][0]
    n_angles = shape_info[ANGLES][0]
    n_torsions = shape_info[TORSIONS][0]

    # check input dimensions:
    plain = make_conditioners(transformer_type, (BONDS, ), (FIXED, ), shape_info, hidden=(128, 128))
    assert all(c._layers[0].weight.shape == (128, n_fixed) for c in plain.values())

    # check input dimensions of wrapped:
    wrapped = make_conditioners(transformer_type, (BONDS, ), (ANGLES, TORSIONS), shape_info, hidden=(128, 128))
    assert all(
        c.net._layers[0].weight.shape == (128, n_angles + 2 * n_torsions)
        for c in wrapped.values()
    )

    # check periodicity
    for net in wrapped.values():
        # random weights, so the periodicity checks are not trivially satisfied
        for param in net.parameters():
            param.data = torch.randn_like(param.data)

        reference = net(torch.zeros(n_angles + n_torsions))

        # shifting every torsion by 1.0 must leave the output unchanged
        probe = torch.zeros(n_angles + n_torsions)
        probe[n_angles:] = 1.0
        assert torch.allclose(reference, net(probe), atol=5e-4)

        # shifting an angle by the same amount must change it
        probe[0] = 1.0
        assert not torch.allclose(reference, net(probe), atol=5e-2)
Exemplo n.º 10
0
def test_circular_affine(crd_trafo):
    """Affine transformers on TORSIONS work only with shift-only conditioners.

    Without ``use_scaling=False`` building the conditioners/transformer must
    raise ``ValueError``. With it, only a shift conditioner is produced and the
    resulting transformer is circular.
    """
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_fixed = shape_info[FIXED][0]
    n_torsions = shape_info[TORSIONS][0]

    with pytest.raises(ValueError):
        scaled = make_conditioners(bgflow.AffineTransformer, (TORSIONS, ), (FIXED, ), shape_info=shape_info)
        make_transformer(bgflow.AffineTransformer, (TORSIONS, ), shape_info, conditioners=scaled)

    shift_only = make_conditioners(
        bgflow.AffineTransformer, (TORSIONS, ), (FIXED, ), shape_info=shape_info, use_scaling=False,
    )
    assert list(shift_only.keys()) == ["shift_transformation"]
    circular = make_transformer(bgflow.AffineTransformer, (TORSIONS, ), shape_info, conditioners=shift_only)
    assert circular._is_circular

    result = circular.forward(torch.zeros(2, n_fixed), torch.zeros(2, n_torsions))
    assert result[0].shape == (2, n_torsions)
Exemplo n.º 11
0
def test_builder_api(ala2, ctx):
    """End-to-end builder API test with a mixed coordinate transform (FIXED block + ICs)."""
    pytest.importorskip("nflows")

    crd_transform = MixedCoordinateTransformation(
        torch.tensor(ala2.xyz, **ctx), ala2.system.z_matrix, ala2.system.rigid_block
    )
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)

    couplings = (
        4 * [(TORSIONS, FIXED), (FIXED, TORSIONS)]
        + 2 * [(BONDS, ANGLES), (ANGLES, BONDS)]
        + [(ANGLES, (TORSIONS, FIXED)), (BONDS, (ANGLES, TORSIONS, FIXED))]
    )
    for transformed, condition in couplings:
        builder.add_condition(transformed, on=condition)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    xyz = generator.sample(2)
    generator.energy(xyz)
    generator.kldiv(10)
Exemplo n.º 12
0
def test_builder_split_merge(ctx):
    """``add_split`` / ``add_merge`` bookkeeping and an identity-like flow.

    First, ANGLES is split into three parts and the generator's samples are
    checked field by field. Then split + merge is built again (string names
    allowed). The test checks the names and circularity of the new fields, and
    that a zero-parameter flow maps prior samples (approximately) onto
    themselves.
    """
    pytest.importorskip("nflows")
    shape_info = ShapeDictionary()
    shape_info[BONDS] = (10, )
    shape_info[ANGLES] = (20, )
    shape_info[TORSIONS] = (13, )
    builder = BoltzmannGeneratorBuilder(shape_info, **ctx)
    split1 = TensorInfo("SPLIT_1")
    split2 = TensorInfo("SPLIT_2")
    split3 = TensorInfo("SPLIT_3")
    builder.add_split(ANGLES, (split1, split2, split3), (6, 2, 12))
    builder.add_condition(split1, on=split2)
    generator = builder.build_generator(zero_parameters=True, check_target=False)
    samples = generator.sample(11)
    assert len(samples) == 5
    # exact list comparison: fails on a wrong count as well as a wrong shape
    assert [tuple(s.shape) for s in samples] == [(11, width) for width in (10, 6, 2, 12, 13)]

    # check split + add_merge (with string arguments)
    assert builder.layers == []
    s1, split_2, s3 = builder.add_split(ANGLES, (split1, "split2", split3), (6, 2, 12))
    assert s1 == split1
    assert s3 == split3
    assert split_2.name == "split2"
    assert split_2.is_circular == ANGLES.is_circular
    builder.add_condition(split1, on=split_2)
    angles = builder.add_merge((split1, split_2, split3), "angles")
    assert angles.name == "angles"
    assert angles.is_circular == ANGLES.is_circular
    assert list(builder.current_dims) == [BONDS, angles, TORSIONS]
    generator = builder.build_generator(zero_parameters=True, check_target=False)
    samples = generator._prior.sample(11)
    assert all(torch.all(s > torch.zeros_like(s)) for s in samples)
    assert all(torch.all(s < torch.ones_like(s)) for s in samples)
    *output, dlogp = generator._flow.forward(*samples)
    # zip() silently truncates, so check the counts before comparing pairwise
    assert len(output) == len(samples)
    assert all(s.shape == o.shape for s, o in zip(samples, output))
    # with zero_parameters=True the flow is expected to act almost as the identity
    assert all(torch.allclose(s, o, atol=0.01, rtol=0.0) for s, o in zip(samples, output))