def _log_abs_det_jacobian(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)

    # Note: params will take care of caching "mean, log_scale, perm = params(x)"
    _, log_scale = params(x, context=context)
    log_scale = clamp_preserve_gradients(
        log_scale, self.log_scale_min_clip, self.log_scale_max_clip
    )
    return log_scale.sum(-1)
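# `clamp_preserve_gradients` is used above but not defined in this notebook.
# A minimal sketch of such a helper, assuming it mirrors Pyro's utility of the
# same name: the forward value is clamped to [min, max], while the gradient
# passes through unchanged (a straight-through clamp).
def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    # value equals x.clamp(min, max); gradient w.r.t. x is the identity
    return x + (x.clamp(min, max) - x).detach()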
def _forward(
    self,
    x: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)

    mean, log_scale = params(x, context=context)
    log_scale = clamp_preserve_gradients(
        log_scale, self.log_scale_min_clip, self.log_scale_max_clip
    )
    scale = torch.exp(log_scale)
    y = scale * x + mean
    return y
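# The transform above is y_i = exp(log_scale_i) * x_i + mean_i, where the
# parameters for dimension i depend only on earlier dimensions, so the Jacobian
# dy/dx is triangular with exp(log_scale) on its diagonal and
# log|det J| = sum_i log_scale_i, which is what _log_abs_det_jacobian returns.
# A self-contained toy check of that identity (constant parameters, so the
# Jacobian is diagonal -- a special case of triangular; not flowtorch code):
import torch

log_scale = torch.randn(4)
mean = torch.randn(4)
x = torch.randn(4)

jac = torch.autograd.functional.jacobian(lambda v: torch.exp(log_scale) * v + mean, x)
print(torch.allclose(log_scale.sum(-1), torch.logdet(jac), atol=1e-5))  # True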
def _inverse(
    self,
    y: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)

    x = torch.zeros_like(y)
    # NOTE: Inversion is an expensive operation that scales in the
    # dimension of the input
    for idx in params.permutation:  # type: ignore
        mean, log_scale = params(x.clone(), context=context)
        inverse_scale = torch.exp(
            -clamp_preserve_gradients(
                log_scale[..., idx],
                min=self.log_scale_min_clip,
                max=self.log_scale_max_clip,
            )
        )
        mean = mean[..., idx]
        x[..., idx] = (y[..., idx] - mean) * inverse_scale
    return x
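# A self-contained toy (not flowtorch code) illustrating why the inverse must
# loop over dimensions in order: with y[i] = exp(s_i(x[<i])) * x[i] + m_i(x[<i]),
# dimension i can only be inverted once the earlier dimensions are known.
import torch

def ar_params(x):
    # toy "hypernetwork": parameters for dim 0 are constants, parameters for
    # dim 1 depend only on x[..., 0] -- the autoregressive property
    zeros = torch.zeros_like(x[..., 0])
    mean = torch.stack([zeros, torch.tanh(x[..., 0])], dim=-1)
    log_scale = torch.stack([zeros, 0.5 * x[..., 0]], dim=-1)
    return mean, log_scale

x_true = torch.randn(5, 2)
mean, log_scale = ar_params(x_true)
y = torch.exp(log_scale) * x_true + mean

x = torch.zeros_like(y)
for idx in range(2):  # same dimension-by-dimension scheme as _inverse above
    mean, log_scale = ar_params(x.clone())
    x[..., idx] = (y[..., idx] - mean[..., idx]) * torch.exp(-log_scale[..., idx])

print(torch.allclose(x, x_true, atol=1e-5))  # True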
# Settings
torch.manual_seed(0)
batch_dim = 100
input_dim = 10

# Create the stateless bijector and the stateful hypernetwork
base_dist = torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim))
bijection = bijectors.AffineAutoregressive()

# Example of lazily instantiating the hypernetwork
# TODO: Remove a layer of indirection from the following (possibly with a class decorator)!
lazy_params = params.DenseAutoregressive(hidden_dims=[50])
# Bind the instantiated hypernetwork to a name other than `params` so that the
# flowtorch `params` module is not shadowed
params_module = lazy_params(torch.Size([input_dim]), bijection.param_shapes(base_dist))

x = base_dist.rsample(torch.Size([batch_dim]))
means, log_sds = params_module(x)
print(means.shape, log_sds.shape, params_module.permutation)

# Try out the low-level methods of the bijector
x = torch.randn(input_dim)
y = bijection.forward(x, params=params_module)
y_inv = bijection.inverse(y, params=params_module)

print(bijection)  # <= testing inheritance from flowtorch.Bijector
print("x", x)
print("y", y)
print("inv(y)", y_inv)
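# Sanity check on the low-level calls above (a hedged sketch reusing the
# variables x and y_inv): inverse(forward(x)) should recover x up to
# floating-point error and the effect of the log-scale clamp.
print("round trip ok:", torch.allclose(x, y_inv, atol=1e-5))
print("max abs reconstruction error:", (x - y_inv).abs().max().item())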
import matplotlib.pyplot as plt
import seaborn as sns

# Settings
# torch.manual_seed(0)  # uncomment for reproducible samples
batch_dim = 100000
input_dim = 128

# Create the hypernetwork and instantiate its (non-lazy) parameters
base_dist = torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim))
bijection = bijectors.AffineAutoregressive()
lazy_params = params.DenseAutoregressive(
    hidden_dims=[256, 256, 256, 256, 256, 256, 256]
)
params_module = lazy_params(torch.Size([input_dim]), bijection.param_shapes(base_dist))

x = base_dist.rsample(torch.Size([batch_dim]))
mean, log_scale = [y.detach().numpy() for y in params_module(x)]
print(mean.shape, log_scale.shape)
print(mean[:, 1].mean(), mean[:, 1].std())

# KDE of the marginal of the second mean output
sns.kdeplot(mean[:, 1], linewidth=3, label='mean')
plt.title('Samples from MADE')
plt.legend()
plt.show()
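# Hedged follow-up sketch: push the same base samples through the bijection
# (using only bijection.forward, as demonstrated earlier) and compare the
# marginal of one dimension before and after the untrained affine
# autoregressive transform.
x_base = base_dist.rsample(torch.Size([batch_dim]))
with torch.no_grad():
    y_flow = bijection.forward(x_base, params=params_module)

sns.kdeplot(x_base[:, 1].numpy(), linewidth=3, label='base')
sns.kdeplot(y_flow[:, 1].numpy(), linewidth=3, label='transformed')
plt.title('Marginal of dimension 1 before/after the transform')
plt.legend()
plt.show()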