def test_readme():
    import torch
    import matplotlib.pyplot as plt
    import bgflow as bg

    # define prior and target
    dim = 2
    prior = bg.NormalDistribution(dim)
    target = bg.DoubleWellEnergy(dim)

    # here we aggregate all layers of the flow
    layers = []
    layers.append(bg.SplitFlow(dim // 2))
    layers.append(
        bg.CouplingFlow(
            # we use an affine transformation to transform the RHS conditioned on the LHS
            bg.AffineTransformer(
                # use simple dense nets for the affine shift/scale
                shift_transformation=bg.DenseNet(
                    [dim // 2, 4, dim // 2], activation=torch.nn.ReLU()
                ),
                scale_transformation=bg.DenseNet(
                    [dim // 2, 4, dim // 2], activation=torch.nn.Tanh()
                ),
            )
        )
    )
    layers.append(bg.InverseFlow(bg.SplitFlow(dim // 2)))

    # now define the flow as a sequence of all operations stored in layers
    flow = bg.SequentialFlow(layers)

    # the BG is defined by a prior, a target, and a flow
    generator = bg.BoltzmannGenerator(prior, flow, target)

    # sample from the BG
    samples = generator.sample(1000)
    plt.hist2d(
        samples[:, 0].detach().numpy(),
        samples[:, 1].detach().numpy(),
        bins=100,
    )
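# A minimal training sketch for the generator built in test_readme; this is
# not part of the original snippet. It assumes bg.KLTrainer takes a
# train_likelihood flag and that train() can run without a data sampler, in
# which case only the energy-based (reverse) KL to the DoubleWellEnergy
# target is minimized.
def train_readme_generator(generator, n_iter=500):
    import bgflow as bg
    trainer = bg.KLTrainer(generator, train_likelihood=False)  # assumed kwarg
    trainer.train(n_iter)  # energy-only training; no data sampler needed
    return generator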
def _coupling_block(self, dim1, dim2):
    return bg.CouplingFlow(
        bg.AffineTransformer(
            shift_transformation=self._dense_net(dim1, dim2),
            scale_transformation=self._dense_net(dim1, dim2),
        )
    )
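# _coupling_block relies on a _dense_net helper that is not shown here. A
# hypothetical sketch of it, following the bg.DenseNet([in, hidden, out],
# activation=...) pattern used in test_readme; the hidden width of 4 is an
# arbitrary choice for illustration.
def _dense_net(self, dim1, dim2, hidden=4):
    import torch
    import bgflow as bg
    return bg.DenseNet([dim1, hidden, dim2], activation=torch.nn.ReLU())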
def test_bg_basic_multiple(device, dtype):
    import torch
    import bgflow as bg

    dim = 4
    mean = torch.zeros(dim // 2, dtype=dtype, device=device)
    prior = bg.ProductDistribution([
        bg.NormalDistribution(dim // 2, mean),
        bg.NormalDistribution(dim // 2, mean),
    ])
    # RealNVP: two affine coupling layers with swaps in between
    flow = bg.SequentialFlow([
        bg.CouplingFlow(
            bg.AffineTransformer(
                bg.DenseNet([dim // 2, dim, dim // 2]),
                bg.DenseNet([dim // 2, dim, dim // 2]),
            )
        ),
        bg.SwapFlow(),
        bg.CouplingFlow(
            bg.AffineTransformer(
                bg.DenseNet([dim // 2, dim, dim // 2]),
                bg.DenseNet([dim // 2, dim, dim // 2]),
            )
        ),
        bg.SwapFlow(),
    ]).to(mean)
    target = bg.ProductDistribution([
        bg.NormalDistribution(dim // 2, mean),
        bg.NormalDistribution(dim // 2, mean),
    ])
    generator = bg.BoltzmannGenerator(prior, flow, target)

    # set all parameters to 0 -> flow = identity
    for p in generator.parameters():
        p.data.zero_()
    z = prior.sample(10)
    *x, dlogp = flow.forward(*z)
    for zi, xi in zip(z, x):
        assert torch.allclose(zi, xi)
    assert torch.allclose(dlogp, torch.zeros_like(dlogp))

    # test losses
    generator.zero_grad()
    kll = generator.kldiv(100000)
    kll.mean().backward()
    # gradients should be small, as the network is already optimal
    for p in generator.parameters():
        assert torch.allclose(p.grad, torch.zeros_like(p.grad), rtol=0.0, atol=5e-2)

    generator.zero_grad()
    samples = target.sample(100000)
    nll = generator.energy(*samples)
    nll.mean().backward()
    # gradients should be small, as the network is already optimal
    for p in generator.parameters():
        assert torch.allclose(p.grad, torch.zeros_like(p.grad), rtol=0.0, atol=5e-2)

    # just testing the API for the following:
    generator.log_weights(*samples)
    *z, dlogp = flow.forward(*samples, inverse=True)
    generator.log_weights_given_latent(samples, z, dlogp)
    generator.sample(10000)
    generator.force(*z)

    # test trainers
    trainer = bg.KLTrainer(generator)
    sampler = bg.ProductSampler([
        bg.DataSetSampler(samples[0]),
        bg.DataSetSampler(samples[1]),
    ]).to(device=device, dtype=dtype)
    trainer.train(100, sampler)
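# test_bg_basic_multiple expects device and dtype to be supplied as pytest
# fixtures. If the repository's conftest.py does not already define them, a
# minimal sketch would be:
import pytest
import torch

@pytest.fixture(params=[torch.device("cpu")])
def device(request):
    return request.param

@pytest.fixture(params=[torch.float32, torch.float64])
def dtype(request):
    return request.param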