    def _get_transformed_x_dist(self, latent, ctx=None):
        x_pred_dist = self.decoder.predict(
            latent,
            ctx)  # returns a normal dist with mean of the predicted image
        if self.laplace_likelihood:
            x_base_dist = Laplace(self.x_base_loc,
                                  self.x_base_scale).to_event(3)
        else:
            x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(
                3)  # treat the rightmost 3 dims (C, H, W) as event dims

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist,
                      (MultivariateNormal, LowRankMultivariateNormal)):
            chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                                 x_pred_dist.scale_tril)
            reshape_transform = ReshapeTransform(self.img_shape,
                                                 (np.prod(self.img_shape), ))
            x_reparam_transform = ComposeTransform(
                [reshape_transform, chol_transform, reshape_transform.inv])
        elif isinstance(x_pred_dist, Independent):
            x_pred_dist = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                                  x_pred_dist.scale, 3)
        else:
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
    def generate(self, x, num_particles):
        z_dist = self.encoder.predict(x)
        z = z_dist.sample()
        x_pred_dist = self.decoder.predict(z)

        x_base_dist = dist.Normal(
            torch.zeros_like(x, requires_grad=False).view(x.shape[0], -1),
            torch.ones_like(x, requires_grad=False).view(x.shape[0], -1),
        ).to_event(1)

        if 'normal' in self.decoder_output or \
                self.decoder_output in ('deepvar', 'deepmean'):
            transform = AffineTransform(x_pred_dist.mean, x_pred_dist.stddev, 1)
        elif self.decoder_output == 'low_rank_mvn':
            transform = LowerCholeskyAffine(x_pred_dist.loc,
                                            x_pred_dist.scale_tril)
        else:
            raise ValueError(f'Unknown decoder output: {self.decoder_output}')

        x_dist = dist.TransformedDistribution(x_base_dist,
                                              ComposeTransform([transform]))

        # Draw num_particles reconstructions and average them.
        recons = []
        for _ in range(num_particles):
            # reshape the flattened sample back to the input image shape
            recon = pyro.sample('x', x_dist).view(x.shape)
            recons.append(recon)
        return torch.stack(recons).mean(0)
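Both methods above use the same reparameterization trick: rather than scoring images under Normal(mu, sigma) directly, they push a fixed standard base distribution through an affine (or Cholesky) transform built from the decoder's predicted parameters, so further transforms (e.g. the preprocessing below) can be composed on top. A minimal sanity check of that equivalence with plain torch.distributions (values here are illustrative):

import torch
import torch.distributions as dist
from torch.distributions.transforms import AffineTransform

mu, sigma = torch.tensor(1.5), torch.tensor(0.4)
direct = dist.Normal(mu, sigma)
reparam = dist.TransformedDistribution(
    dist.Normal(torch.zeros(()), torch.ones(())),
    [AffineTransform(mu, sigma)])
x = torch.linspace(-1., 4., 5)
assert torch.allclose(direct.log_prob(x), reparam.log_prob(x))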
Example #3
def _get_preprocess_transforms(self):
    alpha = 0.05
    num_bits = 8
    if self.preprocessing == 'glow':
        # Map to [-0.5, 0.5]
        a1 = AffineTransform(-0.5, (1. / 2 ** num_bits))
        preprocess_transform = ComposeTransform([a1])
    elif self.preprocessing == 'realnvp':
        # Map to [0, 1]
        a1 = AffineTransform(0., (1. / 2 ** num_bits))
        # Map into unconstrained space as done in RealNVP
        a2 = AffineTransform(alpha, (1 - alpha))
        s = SigmoidTransform()
        preprocess_transform = ComposeTransform([a1, a2, s.inv])
    else:
        raise ValueError(f'{self.preprocessing} not valid.')
    return preprocess_transform
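To see what the 'realnvp' branch does numerically: pixel intensities in [0, 2^num_bits) are squashed into [0, 1), nudged away from the saturated endpoints, and mapped through the inverse sigmoid into unconstrained space. A quick round-trip check (a sketch, not part of the original class):

import torch
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            SigmoidTransform)

alpha, num_bits = 0.05, 8
a1 = AffineTransform(0., 1. / 2 ** num_bits)  # [0, 256) -> [0, 1)
a2 = AffineTransform(alpha, 1 - alpha)        # keep away from 0 and 1
s = SigmoidTransform()
preprocess = ComposeTransform([a1, a2, s.inv])

x = torch.tensor([0., 128., 255.])
u = preprocess(x)  # unconstrained logits
assert torch.allclose(preprocess.inv(u), x, atol=1e-3)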
Example #4
def model():
    fn = dist.TransformedDistribution(
        dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
        [AffineTransform(loc, scale), ExpTransform()])
    if event_shape:
        fn = fn.to_event(len(event_shape))
    with pyro.plate_stack("plates", batch_shape):
        with pyro.plate("particles", 200000):
            return pyro.sample("x", fn)
Example #5
def model():
    with pyro.plate_stack("plates", shape):
        with pyro.plate("particles", 200000):
            return pyro.sample(
                "x",
                dist.TransformedDistribution(
                    dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
                    [AffineTransform(loc, scale), ExpTransform()]))
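Examples #4 and #5 build a log-normal by hand: pushing a standard normal through AffineTransform(loc, scale) followed by ExpTransform yields exp(loc + scale * z), which is exactly LogNormal(loc, scale). A quick check with torch.distributions:

import torch
import torch.distributions as dist
from torch.distributions.transforms import AffineTransform, ExpTransform

loc, scale = torch.tensor(0.3), torch.tensor(0.7)
by_hand = dist.TransformedDistribution(
    dist.Normal(torch.zeros(()), torch.ones(())),
    [AffineTransform(loc, scale), ExpTransform()])
x = torch.tensor([0.5, 1.0, 2.0])
assert torch.allclose(by_hand.log_prob(x),
                      dist.LogNormal(loc, scale).log_prob(x))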
Example #6
def model(n_samples=None, scale=2.):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 2

        transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])

        width = pyro.sample('width', TransformedDistribution(Normal(loc, scale), transforms))

    return thickness, width
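Since SigmoidTransform maps the normal onto (0, 1) and AffineTransform(10, 15) rescales that interval, width is guaranteed to lie in (10, 25) regardless of loc and scale. A small check of the support (illustrative only):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            SigmoidTransform)

transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])
width = TransformedDistribution(Normal(0., 2.), transforms).sample((10000,))
assert ((width > 10) & (width < 25)).all()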
Example #7
    def __init__(self, preprocessing: str = 'realnvp'):
        super().__init__()

        self.preprocessing = preprocessing

        self.register_buffer('thickness_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('thickness_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        self.register_buffer('intensity_flow_norm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('intensity_flow_norm_scale',
                             torch.ones([], requires_grad=False))

        self.thickness_flow_lognorm = AffineTransform(
            loc=self.thickness_flow_lognorm_loc.item(),
            scale=self.thickness_flow_lognorm_scale.item())
        self.intensity_flow_norm = AffineTransform(
            loc=self.intensity_flow_norm_loc.item(),
            scale=self.intensity_flow_norm_scale.item())
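One subtlety in this constructor: each AffineTransform is built from .item() snapshots of the buffers, so later writes to thickness_flow_lognorm_loc and friends do not propagate into the transform by themselves. A hypothetical helper sketching how the two would be kept in sync (the function name is illustrative, not from the original code):

def set_thickness_lognorm_stats(module, loc, scale):
    # update the persistent buffers...
    module.thickness_flow_lognorm_loc.fill_(loc)
    module.thickness_flow_lognorm_scale.fill_(scale)
    # ...and mirror the new values into the AffineTransform
    module.thickness_flow_lognorm.loc = module.thickness_flow_lognorm_loc.item()
    module.thickness_flow_lognorm.scale = module.thickness_flow_lognorm_scale.item()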
Example #8
def model(n_samples=None, scale=0.5, invert=False):
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2

        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])

        intensity = pyro.sample(
            'intensity', TransformedDistribution(Normal(loc, scale),
                                                 transforms))

    return thickness, intensity
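This is the same bounding trick as in Example #6: the sigmoid output in (0, 1) is rescaled by AffineTransform(64, 191), so intensity is confined to (64, 255), the upper part of the 8-bit pixel range.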
Example #9
    def _get_transformed_x_dist(self, latent):
        x_pred_dist = self.decoder.predict(latent)
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist,
                      (MultivariateNormal, LowRankMultivariateNormal)):
            chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                                 x_pred_dist.scale_tril)
            reshape_transform = ReshapeTransform(self.img_shape,
                                                 (np.prod(self.img_shape), ))
            x_reparam_transform = ComposeTransform(
                [reshape_transform, chol_transform, reshape_transform.inv])
        elif isinstance(x_pred_dist, Independent):
            x_pred_dist = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                                  x_pred_dist.scale, 3)
        else:
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
Example #10
    def model(self):
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        with warnings.catch_warnings():
            # affine_grid/grid_sample emit warnings (e.g. about align_corners
            # defaults); silence them for these two calls
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume
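The decoder here acts as a spatial transformer: theta parameterizes a 2x3 affine matrix, affine_grid turns it into a sampling grid, and grid_sample warps the decoded mean image. With the identity theta the warp is a no-op up to interpolation, which is also why the affine head in Example #11 below is initialized to [1, 0, 0, 0, 1, 0]. A standalone sketch of that behaviour:

import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 8, 8)
theta = torch.tensor([[1., 0., 0.], [0., 1., 0.]]).expand(2, 2, 3)
grid = F.affine_grid(theta, x.size(), align_corners=False)
assert torch.allclose(F.grid_sample(x, grid, align_corners=False), x, atol=1e-5)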
Example #11
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # decoder parts
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # ventricle_volume flow
        # TODO: decide on how many things to condition on
        ventricle_volume_net = DenseNN(2, [8, 16],
                                       param_dims=[1, 1],
                                       nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=ventricle_volume_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # brain_volume flow
        # TODO: decide on how many things to condition on
        brain_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=brain_volume_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run counterfactuals?
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)
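Note how decoder_affine_param_net is initialized: zeroing the last layer's weights and setting its bias to [1, 0, 0, 0, 1, 0] means the network emits the identity affine for every input at the start of training, so the spatial transformer initially leaves the decoded image untouched. A minimal reproduction of that trick:

import torch
import torch.nn as nn

head = nn.Linear(16, 6)
nn.init.zeros_(head.weight)
with torch.no_grad():
    head.bias.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

theta = head(torch.randn(4, 16)).view(-1, 2, 3)
assert torch.allclose(theta, torch.eye(2, 3).expand(4, 2, 3))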
Example #12
    def __init__(self,
                 latent_dim: int,
                 logstd_init: float = -5,
                 enc_filters: str = '16,32,64,128',
                 dec_filters: str = '128,64,32,16',
                 num_convolutions: int = 2,
                 use_upconv: bool = False,
                 decoder_type: str = 'fixed_var',
                 decoder_cov_rank: int = 10,
                 **kwargs):
        super().__init__(**kwargs)

        self.img_shape = ((1, 192 // self.downsample, 192 // self.downsample)
                          if self.downsample > 0 else (1, 192, 192))

        self.latent_dim = latent_dim
        self.logstd_init = logstd_init

        self.enc_filters = tuple(
            int(f.strip()) for f in enc_filters.split(','))
        self.dec_filters = tuple(
            int(f.strip()) for f in dec_filters.split(','))
        self.num_convolutions = num_convolutions
        self.use_upconv = use_upconv
        self.decoder_type = decoder_type
        self.decoder_cov_rank = decoder_cov_rank

        # decoder parts
        decoder = Decoder(num_convolutions=self.num_convolutions,
                          filters=self.dec_filters,
                          latent_dim=self.latent_dim + self.context_dim,
                          upconv=self.use_upconv,
                          output_size=self.img_shape)

        if self.decoder_type == 'fixed_var':
            self.decoder = Conv2dIndepNormal(decoder, 1, 1)

            torch.nn.init.zeros_(self.decoder.logvar_head.weight)
            self.decoder.logvar_head.weight.requires_grad = False

            torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                    self.logstd_init)
            self.decoder.logvar_head.bias.requires_grad = False
        elif self.decoder_type == 'learned_var':
            self.decoder = Conv2dIndepNormal(decoder, 1, 1)

            torch.nn.init.zeros_(self.decoder.logvar_head.weight)
            self.decoder.logvar_head.weight.requires_grad = False

            torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                    self.logstd_init)
            self.decoder.logvar_head.bias.requires_grad = True
        elif self.decoder_type == 'independent_gaussian':
            self.decoder = Conv2dIndepNormal(decoder, 1, 1)

            torch.nn.init.zeros_(self.decoder.logvar_head.weight)
            self.decoder.logvar_head.weight.requires_grad = True

            torch.nn.init.normal_(self.decoder.logvar_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logvar_head.bias.requires_grad = True
        elif self.decoder_type == 'multivariate_gaussian':
            seq = torch.nn.Sequential(decoder,
                                      Lambda(lambda x: x.view(x.shape[0], -1)))
            self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                                  np.prod(self.img_shape))
        elif self.decoder_type == 'sharedvar_multivariate_gaussian':
            seq = torch.nn.Sequential(decoder,
                                      Lambda(lambda x: x.view(x.shape[0], -1)))
            self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                                  np.prod(self.img_shape))

            torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
            self.decoder.logdiag_head.weight.requires_grad = False

            torch.nn.init.zeros_(self.decoder.lower_head.weight)
            self.decoder.lower_head.weight.requires_grad = False

            torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logdiag_head.bias.requires_grad = True
        elif self.decoder_type == 'lowrank_multivariate_gaussian':
            seq = torch.nn.Sequential(decoder,
                                      Lambda(lambda x: x.view(x.shape[0], -1)))
            self.decoder = DeepLowRankMultivariateNormal(
                seq, np.prod(self.img_shape), np.prod(self.img_shape),
                decoder_cov_rank)
        elif self.decoder_type == 'sharedvar_lowrank_multivariate_gaussian':
            seq = torch.nn.Sequential(decoder,
                                      Lambda(lambda x: x.view(x.shape[0], -1)))
            self.decoder = DeepLowRankMultivariateNormal(
                seq, np.prod(self.img_shape), np.prod(self.img_shape),
                decoder_cov_rank)

            torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
            self.decoder.logdiag_head.weight.requires_grad = False

            torch.nn.init.zeros_(self.decoder.factor_head.weight)
            self.decoder.factor_head.weight.requires_grad = False

            torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logdiag_head.bias.requires_grad = True
        else:
            raise ValueError(f'Unknown decoder type: {self.decoder_type}')

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim,
                               input_size=self.img_shape)

        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + self.context_dim,
                            self.latent_dim), torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

        # priors
        self.register_buffer('age_base_loc',
                             torch.zeros(1, requires_grad=False))
        self.register_buffer('age_base_scale',
                             torch.ones(1, requires_grad=False))

        self.sex_logits = torch.nn.Parameter(torch.zeros(1))

        self.register_buffer('ventricle_volume_base_loc',
                             torch.zeros(1, requires_grad=False))
        self.register_buffer('ventricle_volume_base_scale',
                             torch.ones(1, requires_grad=False))

        self.register_buffer('brain_volume_base_loc',
                             torch.zeros(1, requires_grad=False))
        self.register_buffer('brain_volume_base_scale',
                             torch.ones(1, requires_grad=False))

        self.register_buffer('z_loc',
                             torch.zeros(latent_dim, requires_grad=False))
        self.register_buffer('z_scale',
                             torch.ones(latent_dim, requires_grad=False))

        self.register_buffer('x_base_loc',
                             torch.zeros(self.img_shape, requires_grad=False))
        self.register_buffer('x_base_scale',
                             torch.ones(self.img_shape, requires_grad=False))

        self.register_buffer('age_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('age_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        self.register_buffer('ventricle_volume_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('ventricle_volume_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        self.register_buffer('brain_volume_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('brain_volume_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(
            loc=self.age_flow_lognorm_loc.item(),
            scale=self.age_flow_lognorm_scale.item())
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # other flows shared components
        self.ventricle_volume_flow_lognorm = AffineTransform(
            loc=self.ventricle_volume_flow_lognorm_loc.item(),
            scale=self.ventricle_volume_flow_lognorm_scale.item())
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])

        self.brain_volume_flow_lognorm = AffineTransform(
            loc=self.brain_volume_flow_lognorm_loc.item(),
            scale=self.brain_volume_flow_lognorm_scale.item())
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])
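The *_flow_constraint_transforms all follow the same recipe: an AffineTransform parameterized by the (log-space) normalization statistics followed by ExpTransform, which maps a standardized value u to exp(loc + scale * u) > 0. Inverting the composition recovers the standardized log value, which is how the model() methods above whiten observed volumes. A round-trip sketch with made-up statistics:

import torch
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            ExpTransform)

loc, scale = torch.tensor(7.1), torch.tensor(0.3)  # illustrative log-volume stats
constraint = ComposeTransform([AffineTransform(loc, scale), ExpTransform()])

u = torch.randn(5)
volume = constraint(u)  # strictly positive
assert torch.allclose(constraint.inv(volume), u, atol=1e-5)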
Example #13
def __neg__(self):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(0, -1)))
Example #14
def __truediv__(self, x: Union[float, Tensor]):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(0, 1 / x)))
Example #15
def __rsub__(self, x: Union[float, Tensor]):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(x, -1)))
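These operator overloads give RandomVariable basic arithmetic by wrapping the underlying distribution in a TransformedDistribution with the matching AffineTransform (negation is scale -1, x - X is loc x with scale -1, and so on). The effect is easy to verify with plain torch.distributions:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

neg = TransformedDistribution(Normal(2., 1.), AffineTransform(0., -1.))
samples = neg.sample((100000,))
print(samples.mean())  # approximately -2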
Example #16
    def __init__(self,
                 latent_dim: int,
                 prior_components: int = 1,
                 posterior_components: int = 1,
                 logstd_init: float = -5,
                 enc_filters: Tuple[int] = (16, 32, 64, 128),
                 dec_filters: Tuple[int] = (128, 64, 32, 16),
                 num_convolutions: int = 3,
                 use_upconv: bool = False,
                 decoder_type: str = 'fixed_var',
                 decoder_cov_rank: int = 10,
                 img_shape: Tuple[int] = (128, 128),
                 use_nvae=False,
                 use_weight_norm=False,
                 use_spectral_norm=False,
                 laplace_likelihood=False,
                 eps=0.1,
                 n_prior_flows=3,
                 n_posterior_flows=3,
                 use_autoregressive=False,
                 use_swish=False,
                 use_spline=False,
                 use_stable=False,
                 pseudo3d=False,
                 head_filters=(16, 16),
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder_shape = ((3, ) if pseudo3d else (1, )) + tuple(img_shape)
        self.decoder_shape = (head_filters[0], ) + tuple(img_shape)
        self.img_shape = (1, ) + tuple(img_shape)
        self.latent_dim = latent_dim
        self.prior_components = prior_components
        self.posterior_components = posterior_components
        self.logstd_init = logstd_init
        self.enc_filters = enc_filters
        self.dec_filters = dec_filters
        self.head_filters = head_filters
        self.num_convolutions = num_convolutions
        self.use_upconv = use_upconv
        self.decoder_type = decoder_type
        self.decoder_cov_rank = decoder_cov_rank
        self.use_nvae = use_nvae
        self.use_weight_norm = use_weight_norm
        self.use_spectral_norm = use_spectral_norm
        self.laplace_likelihood = laplace_likelihood
        self.eps = eps
        self.n_prior_flows = n_prior_flows
        self.n_posterior_flows = n_posterior_flows
        self.use_autoregressive = use_autoregressive
        self.use_spline = use_spline
        self.use_swish = use_swish
        self.use_stable = use_stable
        self.pseudo3d = pseudo3d
        # initialized here; updated during training
        self.annealing_factor = [1.]
        self.n_levels = 0

        # decoder parts
        if use_nvae:
            decoder = NDecoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + self.context_dim,
                               output_size=self.decoder_shape)
        else:
            decoder = Decoder(
                num_convolutions=self.num_convolutions,
                filters=self.dec_filters,
                latent_dim=self.latent_dim + self.context_dim,
                upconv=self.use_upconv,
                output_size=self.decoder_shape,
                use_weight_norm=self.use_weight_norm,
                use_spectral_norm=self.use_spectral_norm,
            )

        self._create_decoder(decoder)

        # encoder parts
        if self.use_nvae:
            self.encoder = NEncoder(num_convolutions=self.num_convolutions,
                                    filters=self.enc_filters,
                                    latent_dim=self.latent_dim,
                                    input_size=self.encoder_shape)
        else:
            self.encoder = Encoder(num_convolutions=self.num_convolutions,
                                   filters=self.enc_filters,
                                   latent_dim=self.latent_dim,
                                   input_size=self.encoder_shape,
                                   use_weight_norm=self.use_weight_norm,
                                   use_spectral_norm=self.use_spectral_norm)

        nonlinearity = Swish() if self.use_swish else torch.nn.LeakyReLU(0.1)
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + self.context_dim,
                            self.latent_dim), nonlinearity)

        if self.posterior_components > 1:
            self.latent_encoder = DeepIndepMixtureNormal(
                latent_layers, self.latent_dim, self.latent_dim,
                self.posterior_components)
        else:
            self.latent_encoder = DeepIndepNormal(latent_layers,
                                                  self.latent_dim,
                                                  self.latent_dim)

        if self.prior_components > 1:
            self.z_loc = torch.nn.Parameter(
                torch.randn([self.prior_components, self.latent_dim]))
            self.z_scale = torch.nn.Parameter(
                torch.randn([self.latent_dim]).clamp(min=-1.))  # log scale
            self.register_buffer(
                'z_components',  # don't be Bayesian about the mixture weights
                ((1 / self.prior_components) *
                 torch.ones(self.prior_components, requires_grad=False)).log())
        else:
            self.register_buffer('z_loc',
                                 torch.zeros(latent_dim, requires_grad=False))
            self.register_buffer('z_scale',
                                 torch.ones(latent_dim, requires_grad=False))
            self.z_components = None

        # priors
        self.sex_logits = torch.nn.Parameter(torch.zeros(1))
        self.register_buffer('slice_number_min',
                             torch.zeros(1, requires_grad=False))
        self.register_buffer(
            'slice_number_max',
            241. * torch.ones(1, requires_grad=False) + 1.)

        for k in self.required_data - {'sex', 'x', 'slice_number'}:
            self.register_buffer(f'{k}_base_loc',
                                 torch.zeros(1, requires_grad=False))
            self.register_buffer(f'{k}_base_scale',
                                 torch.ones(1, requires_grad=False))

        self.register_buffer('x_base_loc',
                             torch.zeros(self.img_shape, requires_grad=False))
        self.register_buffer('x_base_scale',
                             torch.ones(self.img_shape, requires_grad=False))

        for k in self.required_data - {'sex', 'x', 'slice_number'}:
            self.register_buffer(f'{k}_flow_lognorm_loc',
                                 torch.zeros([], requires_grad=False))
            self.register_buffer(f'{k}_flow_lognorm_scale',
                                 torch.ones([], requires_grad=False))

        def perm():
            return torch.randperm(self.latent_dim,
                                  dtype=torch.long,
                                  requires_grad=False)

        self.use_prior_flow = self.n_prior_flows > 0
        self.use_prior_permutations = self.n_prior_flows > 1
        if self.use_prior_permutations:
            for i in range(self.n_prior_flows):
                self.register_buffer(f'prior_flow_permutation_{i}', perm())

        self.use_posterior_flow = self.n_posterior_flows > 0
        self.use_posterior_permutations = self.n_posterior_flows > 1
        if self.use_posterior_permutations:
            for i in range(self.n_posterior_flows):
                self.register_buffer(f'posterior_flow_permutation_{i}', perm())

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(
            loc=self.age_flow_lognorm_loc.item(),
            scale=self.age_flow_lognorm_scale.item())
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # other flows shared components
        self.ventricle_volume_flow_lognorm = AffineTransform(
            loc=self.ventricle_volume_flow_lognorm_loc.item(),
            scale=self.ventricle_volume_flow_lognorm_scale.item())
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])

        self.brain_volume_flow_lognorm = AffineTransform(
            loc=self.brain_volume_flow_lognorm_loc.item(),
            scale=self.brain_volume_flow_lognorm_scale.item())
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])

        self.lesion_volume_flow_lognorm = AffineTransform(
            loc=self.lesion_volume_flow_lognorm_loc.item(),
            scale=self.lesion_volume_flow_lognorm_scale.item())
        self.lesion_volume_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.lesion_volume_flow_constraint_transforms = ComposeTransform([
            self.lesion_volume_flow_lognorm,
            ExpTransform(), self.lesion_volume_flow_eps
        ])

        self.duration_flow_lognorm = AffineTransform(
            loc=self.duration_flow_lognorm_loc.item(),
            scale=self.duration_flow_lognorm_scale.item())
        self.duration_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.duration_flow_constraint_transforms = ComposeTransform([
            self.duration_flow_lognorm,
            ExpTransform(), self.duration_flow_eps
        ])

        self.edss_flow_lognorm = AffineTransform(
            loc=self.edss_flow_lognorm_loc.item(),
            scale=self.edss_flow_lognorm_scale.item())
        self.edss_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.edss_flow_constraint_transforms = ComposeTransform(
            [self.edss_flow_lognorm,
             ExpTransform(), self.edss_flow_eps])

        hidden_dims = ((3 * self.latent_dim + 1, ) if self.use_autoregressive
                       else (2 * self.latent_dim, 2 * self.latent_dim))
        flow_kwargs = dict(hidden_dims=hidden_dims, nonlinearity=nonlinearity)
        if self.use_spline:
            flow_ = spline_autoregressive if self.use_autoregressive else spline_coupling
        else:
            flow_ = affine_autoregressive if self.use_autoregressive else affine_coupling
        if self.use_autoregressive:
            flow_kwargs['stable'] = self.use_stable

        if self.use_prior_permutations:
            self.prior_affine = (iterated(self.n_prior_flows, batchnorm,
                                          self.latent_dim, momentum=0.05)
                                 if self.use_prior_flow else [])
            self.prior_permutations = [
                Permute(getattr(self, f'prior_flow_permutation_{i}'))
                for i in range(self.n_prior_flows)
            ]
            self.prior_flow_components = (
                iterated(self.n_prior_flows, flow_, self.latent_dim,
                         **flow_kwargs) if self.use_prior_flow else [])
            self.prior_flow_transforms = [
                t for triple in zip(self.prior_permutations, self.prior_affine,
                                    self.prior_flow_components) for t in triple
            ]
        else:
            self.prior_affine = []
            self.prior_permutations = []
            self.prior_flow_components = (
                flow_(self.latent_dim, **flow_kwargs)
                if self.use_prior_flow else [])
            self.prior_flow_transforms = [self.prior_flow_components]

        if self.use_posterior_permutations:
            self.posterior_affine = iterated(self.n_posterior_flows,
                                             batchnorm,
                                             self.latent_dim,
                                             momentum=0.05)
            self.posterior_permutations = [
                Permute(getattr(self, f'posterior_flow_permutation_{i}'))
                for i in range(self.n_posterior_flows)
            ]
            self.posterior_flow_components = iterated(self.n_posterior_flows,
                                                      flow_, self.latent_dim,
                                                      **flow_kwargs)
            self.posterior_flow_transforms = [
                t for triple in zip(self.posterior_permutations,
                                    self.posterior_affine,
                                    self.posterior_flow_components)
                for t in triple
            ]
        else:
            self.posterior_affine = []
            self.posterior_permutations = []
            self.posterior_flow_components = (
                flow_(self.latent_dim, **flow_kwargs)
                if self.use_posterior_flow else [])
            self.posterior_flow_transforms = [self.posterior_flow_components]
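Both flow stacks are later composed with the base distributions in the usual way: interleaved permutations, batch-norm transforms, and coupling or autoregressive blocks applied on top of a diagonal Gaussian. A minimal sketch of a flow prior in that style, using Pyro's transform helpers (the latent size and hidden dims here are illustrative):

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import affine_autoregressive

latent_dim = 8
base = dist.Normal(torch.zeros(latent_dim), torch.ones(latent_dim)).to_event(1)
flow = affine_autoregressive(latent_dim, hidden_dims=[16, 16])
prior = dist.TransformedDistribution(base, [flow])

z = prior.rsample()        # a single latent sample, shape (8,)
log_p = prior.log_prob(z)  # scalar log-density under the flow prior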