def test_builder_add_layer_and_param_groups(ctx):
    """Layers added via ``add_layer`` are registered under their parameter groups.

    Builds a generator over two fields (BONDS: 10, ANGLES: 20) with
    1. an inverted CDF transform acting on BONDS only, tagged "group1", and
    2. a coupling flow acting on all fields, tagged "group1" and "group2",
    then compares the builder's parameter groups against the parameters of
    the built generator and exercises sampling and the KL divergence.
    """
    shape_info = ShapeDictionary()
    shape_info[BONDS] = (10, )
    shape_info[ANGLES] = (20, )
    builder = BoltzmannGeneratorBuilder(shape_info, **ctx)
    # transform some fields
    builder.add_layer(
        CDFTransform(
            TruncatedNormalDistribution(
                torch.zeros(10, **ctx),
                # `np.infty` was removed in NumPy 2.0; `np.inf` works on all versions.
                lower_bound=-torch.tensor(np.inf),
            ),
        ),
        what=[BONDS],
        inverse=True,
        param_groups=("group1", )
    )
    # transform all fields
    builder.add_layer(
        CouplingFlow(
            AffineTransformer(
                DenseNet([10, 20]),
                DenseNet([10, 20])
            )
        ),
        param_groups=("group1", "group2")
    )
    builder.targets[BONDS] = NormalDistribution(10, torch.zeros(10, **ctx))
    builder.targets[ANGLES] = NormalDistribution(20, torch.zeros(20, **ctx))
    generator = builder.build_generator().to(**ctx)
    # "group1" is attached to both layers: expected to equal all generator parameters.
    assert builder.param_groups["group1"] == list(generator.parameters())
    # "group2" is only attached to the coupling layer, presumably flow block 1.
    assert builder.param_groups["group2"] == list(generator._flow._blocks[1].parameters())
    generator.sample(10)
    generator.kldiv(10)
def test_builder_bond_constraints(ala2, ctx):
    """Constrained bonds are left out of the flow and merged back before mapping to Cartesians."""
    pytest.importorskip("nflows")
    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(
        crd_transform,
        dim_augmented=0,
        n_constraints=2,
        remove_origin_and_rotation=True,
    )
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    fixed_bond_indices = [0, 1]
    fixed_bond_lengths = [0.1, 0.1]

    # With two constraints, the prior and the flow only see 19 bonds.
    assert builder.current_dims[BONDS] == (19, )
    assert builder.prior_dims[BONDS] == (19, )

    builder.add_condition(BONDS, on=(ANGLES, TORSIONS))
    builder.add_map_to_ic_domains()
    builder.add_merge_constraints(fixed_bond_indices, fixed_bond_lengths)
    # Merging re-inserts the two constrained bonds.
    assert builder.current_dims[BONDS] == (21, )

    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # Round trip: sample Cartesians, then evaluate energy and KL divergence.
    samples = generator.sample(2)
    assert samples.shape == (2, 66)
    generator.energy(samples)
    generator.kldiv(10)
def test_icmarginals_inform_api(tmpdir, ctx, with_data):
    """API test: InternalCoordinateMarginals can be informed from data or from a force field."""
    bgmol = pytest.importorskip("bgmol")
    dataset = bgmol.datasets.Ala2Implicit1000Test(root=tmpdir, download=True, read=True)
    coordinate_transform = GlobalInternalCoordinateTransformation(
        bgmol.systems.ala2.DEFAULT_GLOBAL_Z_MATRIX
    )
    system = dataset.system.system
    # constrained bonds are not part of the marginal dimensions
    n_free_bonds = coordinate_transform.dim_bonds - system.getNumConstraints()

    current_dims = ShapeDictionary()
    current_dims[BONDS] = (n_free_bonds, )
    current_dims[ANGLES] = (coordinate_transform.dim_angles, )
    marginals = InternalCoordinateMarginals(current_dims, ctx)

    if not with_data:
        marginals.inform_with_force_field(system, coordinate_transform, 1000., )
        return

    constrained_indices, _ = bgmol.bond_constraints(system, coordinate_transform)
    marginals.inform_with_data(
        torch.tensor(dataset.xyz, **ctx),
        coordinate_transform,
        constrained_bond_indices=constrained_indices,
    )
def test_builder_multiple_crd(ala2, ctx):
    """Builder chains a coarse-grained and an all-atom coordinate transform."""
    bgmol = pytest.importorskip("bgmol")
    pytest.importorskip("nflows")

    # all-atom transform; the listed atoms are kept in Cartesian form (FIXED)
    z_matrix, fixed = bgmol.ZMatrixFactory(
        ala2.system.mdtraj_topology, cartesian=[6, 8, 10, 14, 16]
    ).build_naive()
    crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform)

    # coarse-grained transform on a fake 5-bead topology, with renamed IC fields
    cg_top, _ = bgmol.build_fake_topology(5)
    cg_z_matrix, _ = bgmol.ZMatrixFactory(cg_top).build_naive()
    cg_crd_transform = GlobalInternalCoordinateTransformation(cg_z_matrix)
    cg_shape_info = ShapeDictionary.from_coordinate_transform(cg_crd_transform)
    CG_BONDS = cg_shape_info.replace(BONDS, "CG_BONDS")
    CG_ANGLES = cg_shape_info.replace(ANGLES, "CG_ANGLES")
    CG_TORSIONS = cg_shape_info.replace(TORSIONS, "CG_TORSIONS")
    cg_fields = dict(bonds=CG_BONDS, angles=CG_ANGLES, torsions=CG_TORSIONS)

    # FIXED is not a prior field; it is produced from the CG ICs further below
    shape_info.update(cg_shape_info)
    del shape_info[FIXED]

    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)

    # stage 1: coarse-grained couplings, then CG ICs -> FIXED Cartesians
    for _ in range(2):
        builder.add_condition(CG_TORSIONS, on=(CG_ANGLES, CG_BONDS))
        builder.add_condition((CG_ANGLES, CG_BONDS), on=CG_TORSIONS)
    cg_marginals = InternalCoordinateMarginals(builder.current_dims, builder.ctx, **cg_fields)
    builder.add_map_to_ic_domains(cg_marginals)
    builder.add_map_to_cartesian(cg_crd_transform, out=FIXED, **cg_fields)
    builder.transformer_type[FIXED] = bg.AffineTransformer

    # stage 2: all-atom couplings conditioned on the FIXED atoms
    for _ in range(2):
        builder.add_condition(TORSIONS, on=FIXED)
        builder.add_condition(FIXED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, FIXED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, FIXED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    samples = generator.sample(2)
    generator.energy(samples)
    generator.kldiv(10)
def test_transformers(crd_trafo, transformer_type):
    """The parametrized transformer type maps BONDS conditioned on FIXED with the right shape."""
    pytest.importorskip("nflows")
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    conditioners = make_conditioners(transformer_type, (BONDS, ), (FIXED, ), shape_info)
    transformer = make_transformer(
        transformer_type, (BONDS, ), shape_info, conditioners=conditioners
    )
    batch_size = 2
    n_fixed = shape_info[FIXED][0]
    n_bonds = shape_info[BONDS][0]
    result = transformer.forward(
        torch.zeros(batch_size, n_fixed), torch.zeros(batch_size, n_bonds)
    )
    assert result[0].shape == (batch_size, n_bonds)
def test_constrain_chirality(ala2, ctx):
    """With ``add_constrain_chirality``, sampled chiral torsions stay within [0.5, 1.0]."""
    bgmol = pytest.importorskip("bgmol")
    topology = ala2.system.mdtraj_topology
    zmatrix, _ = bgmol.ZMatrixFactory(topology).build_naive()
    crd_transform = GlobalInternalCoordinateTransformation(zmatrix)
    builder = BoltzmannGeneratorBuilder(
        ShapeDictionary.from_coordinate_transform(crd_transform),
        target=ala2.system.energy_model,
        **ctx,
    )
    chiral_torsions = bgmol.is_chiral_torsion(crd_transform.torsion_indices, topology)
    builder.add_constrain_chirality(chiral_torsions)
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # sample Cartesians and map them back to internal coordinates
    samples = generator.sample(20)
    _, _, torsions, *_ = crd_transform.forward(samples)
    chiral_values = torsions[:, chiral_torsions]
    assert torch.all(chiral_values >= 0.5)
    assert torch.all(chiral_values <= 1.0)
def test_conditioner_factory_spline(crd_trafo):
    """Spline parameter-net output size depends on whether the fields are periodic.

    Per the asserted formulas: ``3 * n_bins`` outputs per dimension, plus one
    extra output per dimension of a non-periodic field (BONDS) but not of a
    periodic one (TORSIONS).
    """
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_bins = 8  # assumed default spline bin count of make_conditioners
    n_bonds = shape_info[BONDS][0]
    n_torsions = shape_info[TORSIONS][0]

    def output_bias_shape(what, on):
        # bias of the last layer of the parameter net == its output size
        conditioners = make_conditioners(ConditionalSplineTransformer, what, on, shape_info)
        return conditioners["params_net"]._layers[-1].bias.shape

    # non-periodic
    assert output_bias_shape((BONDS, ), (ANGLES, )) == ((3 * n_bins + 1) * n_bonds, )
    # periodic
    assert output_bias_shape((TORSIONS, ), (ANGLES, )) == ((3 * n_bins) * n_torsions, )
    # mixed
    assert output_bias_shape((BONDS, TORSIONS), (ANGLES, FIXED)) == (
        (3 * n_bins) * (n_bonds + n_torsions) + n_bonds,
    )
def test_builder_augmentation_and_global(ala2, ctx):
    """Builder with augmented channels and a global IC transform yields two sample tensors."""
    pytest.importorskip("nflows")
    crd_transform = GlobalInternalCoordinateTransformation(ala2.system.global_z_matrix)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_transform, dim_augmented=10)
    builder = BoltzmannGeneratorBuilder(shape_info, target=ala2.system.energy_model, **ctx)
    for _ in range(4):
        builder.add_condition(TORSIONS, on=AUGMENTED)
        builder.add_condition(AUGMENTED, on=TORSIONS)
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, AUGMENTED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, AUGMENTED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()
    # play forward and backward.
    # The batch size deliberately differs from the number of output fields, so
    # ``len(samples) == 2`` checks for two output tensors (presumably Cartesian
    # + augmented). A single tensor holding 2 samples would not pass.
    n_samples = 3
    samples = generator.sample(n_samples)
    assert len(samples) == 2
    assert all(s.shape[0] == n_samples for s in samples)
    generator.energy(*samples)
    generator.kldiv(10)
def test_conditioner_factory_input_dim(transformer_type, crd_trafo):
    """Conditioner nets get the right input width, and torsion inputs behave periodically."""
    torch.manual_seed(10981)
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    n_fixed = shape_info[FIXED][0]
    n_angles = shape_info[ANGLES][0]
    n_torsions = shape_info[TORSIONS][0]
    hidden = (128, 128)

    # unwrapped conditioners: the first layer sees the FIXED field directly
    plain = make_conditioners(transformer_type, (BONDS, ), (FIXED, ), shape_info, hidden=hidden)
    for net in plain.values():
        assert net._layers[0].weight.shape == (hidden[0], n_fixed)

    # wrapped conditioners: each torsion enters the inner net as two features
    # (presumably cos/sin), angles enter as one feature each
    wrapped = make_conditioners(
        transformer_type, (BONDS, ), (ANGLES, TORSIONS), shape_info, hidden=hidden
    )
    for net in wrapped.values():
        assert net.net._layers[0].weight.shape == (hidden[0], n_angles + 2 * n_torsions)

    # periodicity checks on randomized weights (so outputs are non-trivial)
    for net in wrapped.values():
        for param in net.parameters():
            param.data = torch.randn_like(param.data)
        reference = net(torch.zeros(n_angles + n_torsions))
        probe = torch.cat([torch.zeros(n_angles), torch.ones(n_torsions)])
        # torsions: input 0 and input 1 should give (nearly) the same output
        assert torch.allclose(reference, net(probe), atol=5e-4)
        # angles: shifting the first angle from 0 to 1 must change the output
        probe[0] = 1.0
        assert not torch.allclose(reference, net(probe), atol=5e-2)
def test_circular_affine(crd_trafo):
    """An affine transformer on TORSIONS is only accepted without scaling (shift only)."""
    shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
    affine = bgflow.AffineTransformer

    # without use_scaling=False, building the circular affine transformer raises
    with pytest.raises(ValueError):
        conditioners = make_conditioners(affine, (TORSIONS, ), (FIXED, ), shape_info=shape_info)
        make_transformer(affine, (TORSIONS, ), shape_info, conditioners=conditioners)

    # shift-only variant: a single conditioner and a circular transformer
    conditioners = make_conditioners(
        affine, (TORSIONS, ), (FIXED, ), shape_info=shape_info, use_scaling=False
    )
    assert list(conditioners) == ["shift_transformation"]
    transformer = make_transformer(affine, (TORSIONS, ), shape_info, conditioners=conditioners)
    assert transformer._is_circular

    n_torsions = shape_info[TORSIONS][0]
    result = transformer.forward(
        torch.zeros(2, shape_info[FIXED][0]), torch.zeros(2, n_torsions)
    )
    assert result[0].shape == (2, n_torsions)
def test_builder_api(ala2, ctx):
    """End-to-end builder run with a MixedCoordinateTransformation (rigid block + ICs)."""
    pytest.importorskip("nflows")
    z_matrix = ala2.system.z_matrix
    rigid_atoms = ala2.system.rigid_block
    reference_xyz = torch.tensor(ala2.xyz, **ctx)
    crd_transform = MixedCoordinateTransformation(reference_xyz, z_matrix, rigid_atoms)
    builder = BoltzmannGeneratorBuilder(
        ShapeDictionary.from_coordinate_transform(crd_transform),
        target=ala2.system.energy_model,
        **ctx,
    )

    # couplings between torsions and the fixed (Cartesian) block
    for _ in range(4):
        builder.add_condition(TORSIONS, on=FIXED)
        builder.add_condition(FIXED, on=TORSIONS)
    # couplings between bonds and angles
    for _ in range(2):
        builder.add_condition(BONDS, on=ANGLES)
        builder.add_condition(ANGLES, on=BONDS)
    builder.add_condition(ANGLES, on=(TORSIONS, FIXED))
    builder.add_condition(BONDS, on=(ANGLES, TORSIONS, FIXED))
    builder.add_map_to_ic_domains()
    builder.add_map_to_cartesian(crd_transform)
    generator = builder.build_generator()

    # play forward and backward
    samples = generator.sample(2)
    generator.energy(samples)
    generator.kldiv(10)
def test_builder_split_merge(ctx):
    """``add_split`` / ``add_merge`` partition a field and join it back together.

    1. Splitting ANGLES (20) into (6, 2, 12) makes the generator emit five tensors.
    2. Splitting and then merging restores the original field layout. With
       zeroed parameters, the flow is then approximately the identity on prior samples.
    """
    pytest.importorskip("nflows")
    n_samples = 11
    shape_info = ShapeDictionary()
    shape_info[BONDS] = (10, )
    shape_info[ANGLES] = (20, )
    shape_info[TORSIONS] = (13, )
    builder = BoltzmannGeneratorBuilder(shape_info, **ctx)
    split1 = TensorInfo("SPLIT_1")
    split2 = TensorInfo("SPLIT_2")
    split3 = TensorInfo("SPLIT_3")

    # split only: ANGLES is replaced by three smaller fields
    builder.add_split(ANGLES, (split1, split2, split3), (6, 2, 12))
    builder.add_condition(split1, on=split2)
    generator = builder.build_generator(zero_parameters=True, check_target=False)
    samples = generator.sample(n_samples)
    expected_dims = [10, 6, 2, 12, 13]
    assert len(samples) == len(expected_dims)
    assert all(s.shape == (n_samples, d) for s, d in zip(samples, expected_dims))

    # check split + add_merge (with string arguments)
    assert builder.layers == []
    s1, split_2, s3 = builder.add_split(ANGLES, (split1, "split2", split3), (6, 2, 12))
    assert s1 == split1
    assert s3 == split3
    assert split_2.name == "split2"
    assert split_2.is_circular == ANGLES.is_circular
    builder.add_condition(split1, on=split_2)
    angles = builder.add_merge((split1, split_2, split3), "angles")
    assert angles.name == "angles"
    assert angles.is_circular == ANGLES.is_circular
    assert list(builder.current_dims) == [BONDS, angles, TORSIONS]
    generator = builder.build_generator(zero_parameters=True, check_target=False)

    # prior samples lie strictly inside (0, 1)
    samples = generator._prior.sample(n_samples)
    assert all(torch.all(s > torch.zeros_like(s)) for s in samples)
    assert all(torch.all(s < torch.ones_like(s)) for s in samples)

    # With zeroed parameters the split+merge flow should be close to the identity.
    # Check the number of outputs first: zip() would silently drop any
    # missing or extra tensors, so the checks below would pass vacuously.
    *output, _dlogp = generator._flow.forward(*samples)
    assert len(output) == len(samples)
    assert all(s.shape == o.shape for s, o in zip(samples, output))
    assert all(torch.allclose(s, o, atol=0.01, rtol=0.0) for s, o in zip(samples, output))