def _log_abs_det_jacobian(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)
    # log|det dy/dx| of an elementwise affine map is the sum of the log scales.
    # Note: params takes care of caching the result of "mean, log_scale = params(x)"
    _, log_scale = params(x, context=context)
    log_scale = clamp_preserve_gradients(
        log_scale, self.log_scale_min_clip, self.log_scale_max_clip
    )
    return log_scale.sum(-1)

def _forward(
    self,
    x: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)
    mean, log_scale = params(x, context=context)
    log_scale = clamp_preserve_gradients(
        log_scale, self.log_scale_min_clip, self.log_scale_max_clip
    )
    # Elementwise affine transformation: y = exp(log_scale) * x + mean
    scale = torch.exp(log_scale)
    y = scale * x + mean
    return y

def _inverse(
    self,
    y: torch.Tensor,
    params: Optional[flowtorch.ParamsModule],
    context: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(params, flowtorch.ParamsModule)
    x = torch.zeros_like(y)

    # NOTE: Inversion is an expensive operation that scales in the dimension
    # of the input: each coordinate is recovered sequentially, following the
    # autoregressive ordering given by params.permutation.
    for idx in params.permutation:  # type: ignore
        mean, log_scale = params(x.clone(), context=context)
        inverse_scale = torch.exp(
            -clamp_preserve_gradients(
                log_scale[..., idx],
                min=self.log_scale_min_clip,
                max=self.log_scale_max_clip,
            )
        )
        mean = mean[..., idx]
        x[..., idx] = (y[..., idx] - mean) * inverse_scale

    return x
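
# The three methods above implement an elementwise affine map. A self-contained
# sketch of the same relationships with plain tensors (no flowtorch):
#   y = exp(s) * x + m,    x = (y - m) * exp(-s),    log|det dy/dx| = sum(s),
# where the determinant identity holds because the Jacobian is triangular with
# exp(s) on its diagonal.
import torch

m, s = torch.randn(5), torch.randn(5)
x = torch.randn(5)
y = torch.exp(s) * x + m
x_rec = (y - m) * torch.exp(-s)
print(torch.allclose(x, x_rec, atol=1e-6))
print("log|det J| =", s.sum().item())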

Example 4

# Imports (module paths assumed from the flowtorch prototype API used in this example)
import torch
import flowtorch.bijectors as bijectors
import flowtorch.params as params

# Settings
torch.manual_seed(0)
batch_dim = 100
input_dim = 10

# Create stateless bijector and stateful hypernetwork
base_dist = torch.distributions.Normal(torch.zeros(input_dim),
                                       torch.ones(input_dim))
bijection = bijectors.AffineAutoregressive()
lazy_params = params.DenseAutoregressive(hidden_dims=[50])
params = lazy_params(torch.Size([input_dim]),
                     bijection.param_shapes(base_dist))

x = base_dist.rsample(torch.Size([batch_dim]))
means, log_sds = params(x)

print(means.shape, log_sds.shape, params.permutation)
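
# Optional sanity check (a sketch; assumes params.permutation gives the
# autoregressive ordering used by _inverse above): the hypernetwork is
# MADE-style, so the mean for a given dimension should not change when inputs
# that come later in the ordering are perturbed.
perm = [int(i) for i in params.permutation]
d = perm[0]  # first dimension in the autoregressive ordering
x_perturbed = x.clone()
x_perturbed[..., perm[1:]] += torch.randn_like(x_perturbed[..., perm[1:]])
means_perturbed, _ = params(x_perturbed)
print(torch.allclose(means[..., d], means_perturbed[..., d]))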

# Try out low-level methods of bijector
x = torch.randn(input_dim)
y = bijection.forward(x, params=params)
y_inv = bijection.inverse(y, params=params)

print(bijection)  # <= testing inheritance from flowtorch.Bijector
print("x", x)
print("y", y)
print("inv(y)", y_inv)

# Example of lazily instantiating hypernetwork
# TODO: Remove layer of indirection from the following (possibly with class decorator)!
# Imports (module paths assumed from the flowtorch prototype API used in this example)
import torch
import flowtorch.bijectors as bijectors
import flowtorch.params as params

import matplotlib.pyplot as plt
import seaborn as sns

# Settings
batch_dim = 100000
input_dim = 128

# Create non-lazy parameters
base_dist = torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim))
bijection = bijectors.AffineAutoregressive()
lazy_params = params.DenseAutoregressive(hidden_dims=[256] * 7)
params = lazy_params(torch.Size([input_dim]), bijection.param_shapes(base_dist))

x = base_dist.rsample(torch.Size([batch_dim]))
mean, log_scale = [y.detach().numpy() for y in params(x)]

print(mean.shape, log_scale.shape)
print(mean[:, 1].mean(), mean[:, 1].std())

# KDE of the MADE mean output for dimension 1 across the base samples
sns.kdeplot(mean[:, 1], linewidth=3, label='mean')
plt.title('Samples from MADE')
plt.legend()
plt.show()
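
# Follow-up sketch (assumes bijection.forward broadcasts over the batch
# dimension, as its elementwise definition above suggests): push the base
# samples through the bijection and look at the marginal of the transformed
# samples rather than the raw MADE means.
y = bijection.forward(x, params=params)
sns.kdeplot(y[:, 1].detach().numpy(), linewidth=3, label='y[:, 1]')
plt.title('Transformed samples, dimension 1')
plt.legend()
plt.show()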