def _get_transformed_x_dist(self, latent, ctx=None):
    """Express the decoder's image distribution as a transformed base distribution.

    The decoder's predictive distribution for ``latent`` (and optional
    ``ctx``) is turned into a reparameterisation transform. That transform is
    applied to a fixed base distribution over the last three dimensions and is
    followed by the transforms returned by ``_get_preprocess_transforms``.

    Args:
        latent: decoder input.
        ctx: optional context, forwarded to ``self.decoder.predict``.

    Returns:
        A ``TransformedDistribution`` over images.

    Raises:
        ValueError: if the decoder returns something other than a
            (LowRank)MultivariateNormal or an Independent distribution.
    """
    # Predictive distribution of the decoder, centred on the predicted image.
    pred = self.decoder.predict(latent, ctx)

    # Base noise is Laplace or Normal, with 3 event dims counted from the right.
    base_cls = Laplace if self.laplace_likelihood else Normal
    base = base_cls(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess = self._get_preprocess_transforms()

    if isinstance(pred, (MultivariateNormal, LowRankMultivariateNormal)):
        # Full/low-rank covariance: flatten the image, apply the Cholesky
        # affine map, then reshape back to ``img_shape``.
        cholesky = LowerCholeskyAffine(pred.loc, pred.scale_tril)
        flatten = ReshapeTransform(self.img_shape, (np.prod(self.img_shape), ))
        reparam = ComposeTransform([flatten, cholesky, flatten.inv])
    elif isinstance(pred, Independent):
        # Independent case: elementwise loc + scale * noise over 3 event dims.
        pred = pred.base_dist
        reparam = AffineTransform(pred.loc, pred.scale, 3)
    else:
        raise ValueError(f'{pred} not valid.')

    return TransformedDistribution(base, ComposeTransform([reparam, preprocess]))
def generate(self, x, num_particles):
    """Reconstruct ``x`` by averaging ``num_particles`` decoder samples.

    A single latent ``z`` is sampled from the encoder's distribution for
    ``x``. The decoder's predictive distribution for ``z`` is then written as
    a transform of a standard normal over the per-example flattened input.
    That distribution is sampled ``num_particles`` times under the pyro site
    name ``'x'``, and the mean of the samples is returned.

    Raises:
        Exception: if ``self.decoder_output`` is not a supported decoder type.
    """
    z = self.encoder.predict(x).sample()
    x_pred_dist = self.decoder.predict(z)

    batch = x.shape[0]
    standard_normal = dist.Normal(
        torch.zeros_like(x, requires_grad=False).view(batch, -1),
        torch.ones_like(x, requires_grad=False).view(batch, -1),
    ).to_event(1)

    kind = self.decoder_output
    if 'normal' in kind or kind in ('deepvar', 'deepmean'):
        # Elementwise mean + stddev * noise over one event dim.
        transform = AffineTransform(x_pred_dist.mean, x_pred_dist.stddev, 1)
    elif kind == 'low_rank_mvn':
        transform = LowerCholeskyAffine(x_pred_dist.loc, x_pred_dist.scale_tril)
    else:
        raise Exception('Unknown decoder output')

    x_dist = dist.TransformedDistribution(standard_normal, ComposeTransform([transform]))
    # NOTE(review): every particle reuses the site name 'x'; under a pyro
    # trace handler this would presumably clash — confirm this is only called
    # outside inference.
    samples = [
        pyro.sample('x', x_dist).view(batch, self.shape, 3)
        for _ in range(num_particles)
    ]
    return torch.stack(samples).mean(0)
def _get_preprocess_transforms(self):
    """Return the fixed transform between 8-bit pixel values and model space.

    ``self.preprocessing`` selects the mapping, applied in the forward
    direction:

    * ``'glow'``: ``v / 256 - 0.5``, mapping [0, 256) to [-0.5, 0.5).
    * ``'realnvp'``: ``logit(0.05 + 0.95 * v / 256)``, mapping into
      unconstrained space as in RealNVP.

    Raises:
        ValueError: for any other ``self.preprocessing`` value.
    """
    n_bits = 8
    to_unit = 1. / 2 ** n_bits
    mode = self.preprocessing

    if mode == 'glow':
        # Map to [-0.5,0.5]
        return ComposeTransform([AffineTransform(-0.5, to_unit)])

    if mode == 'realnvp':
        squeeze = 0.05
        return ComposeTransform([
            AffineTransform(0., to_unit),            # map to [0, 1]
            AffineTransform(squeeze, (1 - squeeze)),  # move away from 0
            SigmoidTransform().inv,                  # logit -> unconstrained
        ])

    raise ValueError(f'{self.preprocessing} not valid.')
def model():
    """Draw log-normal samples inside a 200000-particle plate.

    A standard normal is pushed through ``AffineTransform(loc, scale)`` and
    then ``ExpTransform``. ``loc``, ``scale``, ``event_shape`` and
    ``batch_shape`` are read from the enclosing scope.
    """
    standard_normal = dist.Normal(torch.zeros_like(loc), torch.ones_like(scale))
    lognormal = dist.TransformedDistribution(
        standard_normal, [AffineTransform(loc, scale), ExpTransform()])
    if event_shape:
        # Reinterpret the rightmost len(event_shape) dims as event dims.
        lognormal = lognormal.to_event(len(event_shape))
    with pyro.plate_stack("plates", batch_shape), pyro.plate("particles", 200000):
        return pyro.sample("x", lognormal)
def model():
    """Draw log-normal samples inside a 200000-particle plate over ``shape``.

    ``loc``, ``scale`` and ``shape`` are read from the enclosing scope.
    """
    with pyro.plate_stack("plates", shape), pyro.plate("particles", 200000):
        standard_normal = dist.Normal(torch.zeros_like(loc), torch.ones_like(scale))
        to_lognormal = [AffineTransform(loc, scale), ExpTransform()]
        return pyro.sample(
            "x", dist.TransformedDistribution(standard_normal, to_lognormal))
def model(n_samples=None, scale=2.):
    """Toy generative model of (thickness, width) pairs.

    ``thickness ~ Gamma(10, 5)``. ``width`` is a Normal with location
    ``(thickness - 2.5) * 2`` and std ``scale``, squashed by a sigmoid and
    rescaled to the interval (10, 25).

    Args:
        n_samples: size of the ``'observations'`` plate.
        scale: standard deviation of the pre-squash Normal.

    Returns:
        Tuple ``(thickness, width)``.
    """
    to_width_range = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))
        width_dist = TransformedDistribution(
            Normal((thickness - 2.5) * 2, scale), to_width_range)
        width = pyro.sample('width', width_dist)
    return thickness, width
def __init__(self, preprocessing: str = 'realnvp'):
    """Store the preprocessing mode and set up covariate normalisation.

    Registers scalar buffers for the thickness and intensity normalisation
    (locs initialised to 0, scales to 1). It then builds an
    ``AffineTransform`` from each loc/scale pair.

    Args:
        preprocessing: preprocessing mode name, stored on ``self``.
    """
    super().__init__()
    self.preprocessing = preprocessing

    buffer_specs = (
        ('thickness_flow_lognorm_loc', torch.zeros),
        ('thickness_flow_lognorm_scale', torch.ones),
        ('intensity_flow_norm_loc', torch.zeros),
        ('intensity_flow_norm_scale', torch.ones),
    )
    for buffer_name, factory in buffer_specs:
        self.register_buffer(buffer_name, factory([], requires_grad=False))

    # NOTE(review): ``.item()`` snapshots the current buffer values as Python
    # floats, so these transforms do not follow later changes to the buffers
    # unless something elsewhere updates them — confirm.
    self.thickness_flow_lognorm = AffineTransform(
        loc=self.thickness_flow_lognorm_loc.item(),
        scale=self.thickness_flow_lognorm_scale.item())
    self.intensity_flow_norm = AffineTransform(
        loc=self.intensity_flow_norm_loc.item(),
        scale=self.intensity_flow_norm_scale.item())
def model(n_samples=None, scale=0.5, invert=False):
    """Toy generative model of (thickness, intensity) pairs.

    ``thickness = 0.5 + Gamma(10, 5)``. ``intensity`` is a Normal whose
    location is linear in thickness, with slope 2 normally and slope -2 when
    ``invert`` is set. It is squashed by a sigmoid and rescaled to the
    interval (64, 255).

    Args:
        n_samples: size of the ``'observations'`` plate.
        scale: standard deviation of the pre-squash Normal.
        invert: flip the thickness/intensity relationship.

    Returns:
        Tuple ``(thickness, intensity)``.
    """
    to_intensity_range = ComposeTransform(
        [SigmoidTransform(), AffineTransform(64, 191)])
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))
        loc = (thickness - 2) * -2 if invert else (thickness - 2.5) * 2
        intensity_dist = TransformedDistribution(Normal(loc, scale), to_intensity_range)
        intensity = pyro.sample('intensity', intensity_dist)
    return thickness, intensity
def _get_transformed_x_dist(self, latent):
    """Express the decoder's image distribution as a transformed Normal.

    The decoder's predictive distribution for ``latent`` is turned into a
    reparameterisation transform. That transform is applied to a fixed Normal
    base distribution over the last three dimensions and is followed by the
    transforms returned by ``_get_preprocess_transforms``.

    Args:
        latent: decoder input.

    Returns:
        A ``TransformedDistribution`` over images.

    Raises:
        ValueError: if the decoder returns something other than a
            (LowRank)MultivariateNormal or an Independent distribution.
    """
    x_pred_dist = self.decoder.predict(latent)

    # Validate the decoder output first. An unsupported type previously fell
    # through every branch and crashed with UnboundLocalError on
    # ``x_reparam_transform``.
    if isinstance(x_pred_dist, (MultivariateNormal, LowRankMultivariateNormal)):
        # Flatten the image, apply the Cholesky affine map, reshape back.
        chol_transform = LowerCholeskyAffine(x_pred_dist.loc, x_pred_dist.scale_tril)
        reshape_transform = ReshapeTransform(self.img_shape, (np.prod(self.img_shape), ))
        x_reparam_transform = ComposeTransform(
            [reshape_transform, chol_transform, reshape_transform.inv])
    elif isinstance(x_pred_dist, Independent):
        # Elementwise loc + scale * noise over 3 event dims.
        x_pred_dist = x_pred_dist.base_dist
        x_reparam_transform = AffineTransform(x_pred_dist.loc, x_pred_dist.scale, 3)
    else:
        raise ValueError(f'{x_pred_dist} not valid.')

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess_transform = self._get_preprocess_transforms()
    return TransformedDistribution(
        x_base_dist, ComposeTransform([x_reparam_transform, preprocess_transform]))
def model(self):
    """Generative model: covariates from the PGM, then a latent code, then an image.

    The image mean is decoded from ``z`` and the unconstrained covariates. It
    is warped by a per-sample affine transform predicted from the same
    latent, and the image is sampled from a Normal base distribution pushed
    through that affine reparameterisation and the preprocessing transforms.

    Returns:
        Tuple ``(x, z, age, sex, ventricle_volume, brain_volume)``.
    """
    age, sex, ventricle_volume, brain_volume = self.pgm_model()

    # Invert the constraint transforms to get the covariates used as
    # decoder context.
    ventricle_ctx = self.ventricle_volume_flow_constraint_transforms.inv(ventricle_volume)
    brain_ctx = self.brain_volume_flow_constraint_transforms.inv(brain_volume)
    age_ctx = self.age_flow_constraint_transforms.inv(age)

    z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))
    decoder_input = torch.cat([z, age_ctx, ventricle_ctx, brain_ctx], 1)

    mean_image = self.decoder_mean(self.decoder(decoder_input))
    pixel_scale = torch.exp(self.decoder_logstd)

    # One 2x3 affine matrix per sample, used to warp the decoded mean.
    affine_params = self.decoder_affine_param_net(decoder_input).view(-1, 2, 3)
    with warnings.catch_warnings():
        # NOTE(review): silences torch's align_corners default warning (and
        # any other warning raised here).
        warnings.simplefilter("ignore")
        sampling_grid = nn.functional.affine_grid(affine_params, mean_image.size())
        warped_mean = nn.functional.grid_sample(mean_image, sampling_grid)

    base = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess = self._get_preprocess_transforms()
    image_dist = TransformedDistribution(
        base,
        ComposeTransform([AffineTransform(warped_mean, pixel_scale, 3), preprocess]))
    x = pyro.sample('x', image_dist)
    return x, z, age, sex, ventricle_volume, brain_volume
def __init__(self, **kwargs):
    """Build the decoder, covariate flows and encoders of this model variant.

    Calls the parent constructor with ``**kwargs``. It then defines:

    * a decoder with a scalar learned log-std and an affine-parameter head;
    * the age / ventricle volume / brain volume flows;
    * the image encoder and the latent encoder.

    ``num_convolutions``, ``dec_filters``, ``enc_filters``, ``latent_dim``,
    ``use_upconv`` and ``logstd_init`` are read from ``self``. They are
    presumably set by the parent constructor — confirm.
    """
    super().__init__(**kwargs)
    # decoder parts
    # NOTE(review): ``+ 3`` presumably makes room for three covariates
    # concatenated to z before decoding — confirm against callers.
    self.decoder = Decoder(num_convolutions=self.num_convolutions,
                           filters=self.dec_filters,
                           latent_dim=self.latent_dim + 3,
                           upconv=self.use_upconv)
    self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
    # A single scalar log standard deviation, initialised to logstd_init.
    self.decoder_logstd = torch.nn.Parameter(
        torch.ones([]) * self.logstd_init)
    # MLP from the decoder input to 6 affine parameters per sample.
    self.decoder_affine_param_net = nn.Sequential(
        nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
        nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
        nn.Linear(self.latent_dim, 6))
    # Zero weights plus bias [1, 0, 0, 0, 1, 0] make the initial output,
    # read as a 2x3 matrix, the identity [[1, 0, 0], [0, 1, 0]] for any input.
    self.decoder_affine_param_net[-1].weight.data.zero_()
    self.decoder_affine_param_net[-1].bias.data.copy_(
        torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
    # age flow
    # Learned spline, then a fixed affine normalisation, then exp (so the
    # constrained output is positive).
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])
    # ventricle_volume flow
    # TODO: decide on how many things to condition on
    # Conditional affine flow whose parameters come from a small DenseNN
    # (input size 2).
    ventricle_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
    self.ventricle_volume_flow_components = ConditionalAffineTransform(
        context_nn=ventricle_volume_net, event_dim=0)
    self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm, ExpTransform()])
    # Kept as a plain list rather than a ComposeTransform, presumably because
    # the first element is conditional and must be conditioned first — confirm.
    self.ventricle_volume_flow_transforms = [
        self.ventricle_volume_flow_components,
        self.ventricle_volume_flow_constraint_transforms
    ]
    # brain_volume flow
    # TODO: decide on how many things to condition on
    brain_volume_net = DenseNN(2, [8, 16],
                               param_dims=[1, 1],
                               nonlinearity=torch.nn.Identity())
    self.brain_volume_flow_components = ConditionalAffineTransform(
        context_nn=brain_volume_net, event_dim=0)
    self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
    self.brain_volume_flow_transforms = [
        self.brain_volume_flow_components,
        self.brain_volume_flow_constraint_transforms
    ]
    # encoder parts
    self.encoder = Encoder(num_convolutions=self.num_convolutions,
                           filters=self.enc_filters,
                           latent_dim=self.latent_dim)
    # TODO: do we need to replicate the PGM here to be able to run counterfactuals? oO
    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
        torch.nn.ReLU())
    self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                          self.latent_dim)
def __init__(self,
             latent_dim: int,
             logstd_init: float = -5,
             enc_filters: str = '16,32,64,128',
             dec_filters: str = '128,64,32,16',
             num_convolutions: int = 2,
             use_upconv: bool = False,
             decoder_type: str = 'fixed_var',
             decoder_cov_rank: int = 10,
             **kwargs):
    """Configure the image shape, decoder/encoder networks, priors and covariate flows.

    ``self.downsample`` and ``self.context_dim`` are read here. They are
    presumably set by the parent constructor — confirm.

    Args:
        latent_dim: size of the latent code ``z``.
        logstd_init: initial value for the decoder's log-scale head bias.
        enc_filters: comma-separated encoder channel counts, parsed to a tuple of ints.
        dec_filters: comma-separated decoder channel counts, parsed to a tuple of ints.
        num_convolutions: forwarded to the encoder and decoder.
        use_upconv: forwarded to the decoder.
        decoder_type: one of ``'fixed_var'``, ``'learned_var'``,
            ``'independent_gaussian'``, ``'multivariate_gaussian'``,
            ``'sharedvar_multivariate_gaussian'``,
            ``'lowrank_multivariate_gaussian'`` or
            ``'sharedvar_lowrank_multivariate_gaussian'``.
        decoder_cov_rank: covariance-factor rank for the low-rank decoders.
        **kwargs: forwarded to the parent constructor.

    Raises:
        ValueError: if ``decoder_type`` is not one of the values above.
    """
    super().__init__(**kwargs)
    # 192x192 single-channel images, divided by ``downsample`` when it is positive.
    self.img_shape = (1, 192 // self.downsample, 192 // self.downsample) if self.downsample > 0 else (1, 192, 192)
    self.latent_dim = latent_dim
    self.logstd_init = logstd_init
    self.enc_filters = tuple(
        int(f.strip()) for f in enc_filters.split(','))
    self.dec_filters = tuple(
        int(f.strip()) for f in dec_filters.split(','))
    self.num_convolutions = num_convolutions
    self.use_upconv = use_upconv
    self.decoder_type = decoder_type
    self.decoder_cov_rank = decoder_cov_rank
    # decoder parts
    decoder = Decoder(num_convolutions=self.num_convolutions,
                      filters=self.dec_filters,
                      latent_dim=self.latent_dim + self.context_dim,
                      upconv=self.use_upconv,
                      output_size=self.img_shape)
    if self.decoder_type == 'fixed_var':
        # Zero, frozen weights and a frozen bias: the log-variance head
        # outputs the constant logstd_init.
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)
        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = False
        torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                self.logstd_init)
        self.decoder.logvar_head.bias.requires_grad = False
    elif self.decoder_type == 'learned_var':
        # As 'fixed_var', but the bias (input-independent) is trainable.
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)
        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = False
        torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                self.logstd_init)
        self.decoder.logvar_head.bias.requires_grad = True
    elif self.decoder_type == 'independent_gaussian':
        # Fully trainable log-variance head; bias initialised around logstd_init.
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)
        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = True
        torch.nn.init.normal_(self.decoder.logvar_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logvar_head.bias.requires_grad = True
    elif self.decoder_type == 'multivariate_gaussian':
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                              np.prod(self.img_shape))
    elif self.decoder_type == 'sharedvar_multivariate_gaussian':
        # Zero, frozen head weights: the covariance does not depend on the
        # input; only the log-diagonal bias is trained.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                              np.prod(self.img_shape))
        torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
        self.decoder.logdiag_head.weight.requires_grad = False
        torch.nn.init.zeros_(self.decoder.lower_head.weight)
        self.decoder.lower_head.weight.requires_grad = False
        torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logdiag_head.bias.requires_grad = True
    elif self.decoder_type == 'lowrank_multivariate_gaussian':
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepLowRankMultivariateNormal(
            seq, np.prod(self.img_shape), np.prod(self.img_shape),
            decoder_cov_rank)
    elif self.decoder_type == 'sharedvar_lowrank_multivariate_gaussian':
        # Zero, frozen head weights: the covariance does not depend on the
        # input; only the log-diagonal bias is trained.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepLowRankMultivariateNormal(
            seq, np.prod(self.img_shape), np.prod(self.img_shape),
            decoder_cov_rank)
        torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
        self.decoder.logdiag_head.weight.requires_grad = False
        torch.nn.init.zeros_(self.decoder.factor_head.weight)
        self.decoder.factor_head.weight.requires_grad = False
        torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logdiag_head.bias.requires_grad = True
    else:
        # The old message was the truncated 'unknown '; name the bad value.
        raise ValueError(f'unknown decoder_type: {self.decoder_type!r}')
    # encoder parts
    self.encoder = Encoder(num_convolutions=self.num_convolutions,
                           filters=self.enc_filters,
                           latent_dim=self.latent_dim,
                           input_size=self.img_shape)
    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + self.context_dim, self.latent_dim),
        torch.nn.ReLU())
    self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                          self.latent_dim)
    # priors
    self.register_buffer('age_base_loc',
                         torch.zeros([
                             1,
                         ], requires_grad=False))
    self.register_buffer('age_base_scale',
                         torch.ones([
                             1,
                         ], requires_grad=False))
    self.sex_logits = torch.nn.Parameter(torch.zeros([
        1,
    ]))
    self.register_buffer('ventricle_volume_base_loc',
                         torch.zeros([
                             1,
                         ], requires_grad=False))
    self.register_buffer('ventricle_volume_base_scale',
                         torch.ones([
                             1,
                         ], requires_grad=False))
    self.register_buffer('brain_volume_base_loc',
                         torch.zeros([
                             1,
                         ], requires_grad=False))
    self.register_buffer('brain_volume_base_scale',
                         torch.ones([
                             1,
                         ], requires_grad=False))
    self.register_buffer('z_loc',
                         torch.zeros([
                             latent_dim,
                         ], requires_grad=False))
    self.register_buffer('z_scale',
                         torch.ones([
                             latent_dim,
                         ], requires_grad=False))
    self.register_buffer('x_base_loc',
                         torch.zeros(self.img_shape, requires_grad=False))
    self.register_buffer('x_base_scale',
                         torch.ones(self.img_shape, requires_grad=False))
    self.register_buffer('age_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('age_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))
    self.register_buffer('ventricle_volume_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('ventricle_volume_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))
    self.register_buffer('brain_volume_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('brain_volume_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))
    # age flow
    # NOTE(review): ``.item()`` snapshots the buffer values as floats; later
    # buffer updates only reach these transforms if something elsewhere
    # propagates them — confirm.
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    self.age_flow_lognorm = AffineTransform(
        loc=self.age_flow_lognorm_loc.item(),
        scale=self.age_flow_lognorm_scale.item())
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])
    # other flows shared components
    self.ventricle_volume_flow_lognorm = AffineTransform(
        loc=self.ventricle_volume_flow_lognorm_loc.item(),
        scale=self.ventricle_volume_flow_lognorm_scale.item())  # noqa: E501
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm, ExpTransform()])
    self.brain_volume_flow_lognorm = AffineTransform(
        loc=self.brain_volume_flow_lognorm_loc.item(),
        scale=self.brain_volume_flow_lognorm_scale.item())
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
def __neg__(self):
    """Return ``-self`` as a new RandomVariable, via the affine map 0 + (-1) * value."""
    negate = AffineTransform(0, -1)
    return RandomVariable(TransformedDistribution(self.distribution, negate))
def __truediv__(self, x: Union[float, Tensor]):
    """Return ``self / x`` as a new RandomVariable.

    Division is expressed as the affine map ``0 + (1 / x) * value``.
    """
    base = self.distribution
    reciprocal = 1 / x
    return RandomVariable(
        TransformedDistribution(base, AffineTransform(0, reciprocal)))
def __rsub__(self, x: Union[float, Tensor]):
    """Return ``x - self`` as a new RandomVariable, via the affine map x + (-1) * value."""
    base = self.distribution
    reflect_about_x = AffineTransform(x, -1)
    return RandomVariable(TransformedDistribution(base, reflect_about_x))
def __init__(self,
             latent_dim: int,
             prior_components: int = 1,
             posterior_components: int = 1,
             logstd_init: float = -5,
             enc_filters: Tuple[int] = (16, 32, 64, 128),
             dec_filters: Tuple[int] = (128, 64, 32, 16),
             num_convolutions: int = 3,
             use_upconv: bool = False,
             decoder_type: str = 'fixed_var',
             decoder_cov_rank: int = 10,
             img_shape: Tuple[int] = (128, 128),
             use_nvae=False,
             use_weight_norm=False,
             use_spectral_norm=False,
             laplace_likelihood=False,
             eps=0.1,
             n_prior_flows=3,
             n_posterior_flows=3,
             use_autoregressive=False,
             use_swish=False,
             use_spline=False,
             use_stable=False,
             pseudo3d=False,
             head_filters=(16, 16),
             **kwargs):
    """Build the encoder/decoder, priors, covariate flows and latent flows.

    ``self.context_dim`` and ``self.required_data`` are read here. They are
    presumably provided by the parent class — confirm.

    Args:
        latent_dim: size of the latent code ``z``.
        prior_components: if > 1, the prior over ``z`` gets that many learned
            component locations and a fixed uniform ``z_components`` log-weight buffer.
        posterior_components: if > 1, the latent encoder is a mixture.
        pseudo3d: if True, the encoder input has 3 channels instead of 1.
        head_filters: its first entry sets the decoder output channel count.
        img_shape: spatial image shape; ``self.img_shape`` is ``(1,) + img_shape``.
        use_nvae: use ``NDecoder``/``NEncoder`` instead of ``Decoder``/``Encoder``.
        eps: offset subtracted after ``exp`` in the lesion volume, duration
            and EDSS constraint transforms.
        n_prior_flows, n_posterior_flows: number of latent flow layers.
            0 disables the flow; > 1 interleaves permutations and batchnorm.
        use_autoregressive, use_spline, use_stable: choose the flow family.
        use_swish: use ``Swish`` instead of ``LeakyReLU(0.1)``.
        **kwargs: forwarded to the parent constructor.
        The remaining arguments are stored on ``self`` and forwarded to the
        networks they configure.
    """
    super().__init__(**kwargs)
    self.encoder_shape = ((3, ) if pseudo3d else (1, )) + tuple(img_shape)
    self.decoder_shape = (head_filters[0], ) + tuple(img_shape)
    self.img_shape = (1, ) + tuple(img_shape)
    self.latent_dim = latent_dim
    self.prior_components = prior_components
    self.posterior_components = posterior_components
    self.logstd_init = logstd_init
    self.enc_filters = enc_filters
    self.dec_filters = dec_filters
    self.head_filters = head_filters
    self.num_convolutions = num_convolutions
    self.use_upconv = use_upconv
    self.decoder_type = decoder_type
    self.decoder_cov_rank = decoder_cov_rank
    self.use_nvae = use_nvae
    self.use_weight_norm = use_weight_norm
    self.use_spectral_norm = use_spectral_norm
    self.laplace_likelihood = laplace_likelihood
    self.eps = eps
    self.n_prior_flows = n_prior_flows
    self.n_posterior_flows = n_posterior_flows
    self.use_autoregressive = use_autoregressive
    self.use_spline = use_spline
    self.use_swish = use_swish
    self.use_stable = use_stable
    self.pseudo3d = pseudo3d
    self.annealing_factor = [
        1.
    ]  # initialize here; will be changed during training
    self.n_levels = 0
    # decoder parts
    if use_nvae:
        decoder = NDecoder(num_convolutions=self.num_convolutions,
                           filters=self.dec_filters,
                           latent_dim=self.latent_dim + self.context_dim,
                           output_size=self.decoder_shape)
    else:
        decoder = Decoder(
            num_convolutions=self.num_convolutions,
            filters=self.dec_filters,
            latent_dim=self.latent_dim + self.context_dim,
            upconv=self.use_upconv,
            output_size=self.decoder_shape,
            use_weight_norm=self.use_weight_norm,
            use_spectral_norm=self.use_spectral_norm,
        )
    self._create_decoder(decoder)
    # encoder parts
    if self.use_nvae:
        self.encoder = NEncoder(num_convolutions=self.num_convolutions,
                                filters=self.enc_filters,
                                latent_dim=self.latent_dim,
                                input_size=self.encoder_shape)
    else:
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim,
                               input_size=self.encoder_shape,
                               use_weight_norm=self.use_weight_norm,
                               use_spectral_norm=self.use_spectral_norm)
    # The same nonlinearity instance is reused in latent_layers and the flows.
    nonlinearity = Swish() if self.use_swish else torch.nn.LeakyReLU(0.1)
    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + self.context_dim, self.latent_dim),
        nonlinearity)
    if self.posterior_components > 1:
        self.latent_encoder = DeepIndepMixtureNormal(
            latent_layers, self.latent_dim, self.latent_dim,
            self.posterior_components)
    else:
        self.latent_encoder = DeepIndepNormal(latent_layers,
                                              self.latent_dim,
                                              self.latent_dim)
    if self.prior_components > 1:
        self.z_loc = torch.nn.Parameter(
            torch.randn([self.prior_components, self.latent_dim]))
        self.z_scale = torch.nn.Parameter(
            torch.randn([self.latent_dim]).clamp(min=-1.,
                                                 max=None))  # log scale
        self.register_buffer(
            'z_components',  # don't be bayesian about the mixture components
            ((1 / self.prior_components) * torch.ones(
                [self.prior_components], requires_grad=False)).log())
    else:
        self.register_buffer(
            'z_loc', torch.zeros([
                latent_dim,
            ], requires_grad=False))
        self.register_buffer(
            'z_scale', torch.ones([
                latent_dim,
            ], requires_grad=False))
        self.z_components = None
    # priors
    self.sex_logits = torch.nn.Parameter(torch.zeros([
        1,
    ]))
    self.register_buffer('slice_number_min',
                         torch.zeros([
                             1,
                         ], requires_grad=False))
    self.register_buffer(
        'slice_number_max', 241. * torch.ones([
            1,
        ], requires_grad=False) + 1.)
    for k in self.required_data - {'sex', 'x', 'slice_number'}:
        self.register_buffer(f'{k}_base_loc',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer(f'{k}_base_scale',
                             torch.ones([
                                 1,
                             ], requires_grad=False))
    self.register_buffer('x_base_loc',
                         torch.zeros(self.img_shape, requires_grad=False))
    self.register_buffer('x_base_scale',
                         torch.ones(self.img_shape, requires_grad=False))
    for k in self.required_data - {'sex', 'x', 'slice_number'}:
        self.register_buffer(f'{k}_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer(f'{k}_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))
    # Fresh random permutation of the latent dimensions for each flow layer.
    perm = lambda: torch.randperm(
        self.latent_dim, dtype=torch.long, requires_grad=False)
    self.use_prior_flow = self.n_prior_flows > 0
    self.use_prior_permutations = self.n_prior_flows > 1
    if self.use_prior_permutations:
        for i in range(self.n_prior_flows):
            self.register_buffer(f'prior_flow_permutation_{i}', perm())
    self.use_posterior_flow = self.n_posterior_flows > 0
    self.use_posterior_permutations = self.n_posterior_flows > 1
    if self.use_posterior_permutations:
        for i in range(self.n_posterior_flows):
            self.register_buffer(f'posterior_flow_permutation_{i}', perm())
    # age flow
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    self.age_flow_lognorm = AffineTransform(
        loc=self.age_flow_lognorm_loc.item(),
        scale=self.age_flow_lognorm_scale.item())
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])
    # other flows shared components
    self.ventricle_volume_flow_lognorm = AffineTransform(
        loc=self.ventricle_volume_flow_lognorm_loc.item(),
        scale=self.ventricle_volume_flow_lognorm_scale.item())  # noqa: E501
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm, ExpTransform()])
    self.brain_volume_flow_lognorm = AffineTransform(
        loc=self.brain_volume_flow_lognorm_loc.item(),
        scale=self.brain_volume_flow_lognorm_scale.item())
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
    # Lesion volume, duration and EDSS: exp(...) - eps, so the constrained
    # values are bounded below by -eps.
    self.lesion_volume_flow_lognorm = AffineTransform(
        loc=self.lesion_volume_flow_lognorm_loc.item(),
        scale=self.lesion_volume_flow_lognorm_scale.item())
    self.lesion_volume_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.lesion_volume_flow_constraint_transforms = ComposeTransform([
        self.lesion_volume_flow_lognorm,
        ExpTransform(), self.lesion_volume_flow_eps
    ])
    self.duration_flow_lognorm = AffineTransform(
        loc=self.duration_flow_lognorm_loc.item(),
        scale=self.duration_flow_lognorm_scale.item())
    self.duration_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.duration_flow_constraint_transforms = ComposeTransform([
        self.duration_flow_lognorm,
        ExpTransform(), self.duration_flow_eps
    ])
    self.edss_flow_lognorm = AffineTransform(
        loc=self.edss_flow_lognorm_loc.item(),
        scale=self.edss_flow_lognorm_scale.item())
    self.edss_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.edss_flow_constraint_transforms = ComposeTransform(
        [self.edss_flow_lognorm,
         ExpTransform(), self.edss_flow_eps])
    hidden_dims = (3 * self.latent_dim + 1, ) if self.use_autoregressive else (2 * self.latent_dim, 2 * self.latent_dim)
    flow_kwargs = dict(hidden_dims=hidden_dims, nonlinearity=nonlinearity)
    if self.use_spline:
        flow_ = spline_autoregressive if self.use_autoregressive else spline_coupling
    else:
        flow_ = affine_autoregressive if self.use_autoregressive else affine_coupling
    if self.use_autoregressive:
        flow_kwargs['stable'] = self.use_stable
    if self.use_prior_permutations:
        # Interleave (permutation, batchnorm, flow) for every prior layer.
        self.prior_affine = iterated(
            self.n_prior_flows, batchnorm, self.latent_dim,
            momentum=0.05) if self.use_prior_flow else []
        self.prior_permutations = [
            Permute(getattr(self, f'prior_flow_permutation_{i}'))
            for i in range(self.n_prior_flows)
        ]
        self.prior_flow_components = iterated(
            self.n_prior_flows, flow_, self.latent_dim, **
            flow_kwargs) if self.use_prior_flow else []
        self.prior_flow_transforms = [
            x for c in zip(self.prior_permutations, self.prior_affine,
                           self.prior_flow_components) for x in c
        ]
    else:
        self.prior_affine = []
        self.prior_permutations = []
        self.prior_flow_components = flow_(
            self.latent_dim, **flow_kwargs) if self.use_prior_flow else []
        # With the flow disabled this used to be ``[[]]`` (a list holding an
        # empty list), which is not a valid transform sequence; use ``[]``.
        self.prior_flow_transforms = [
            self.prior_flow_components
        ] if self.use_prior_flow else []
    if self.use_posterior_permutations:
        self.posterior_affine = iterated(self.n_posterior_flows,
                                         batchnorm,
                                         self.latent_dim,
                                         momentum=0.05)
        self.posterior_permutations = [
            Permute(getattr(self, f'posterior_flow_permutation_{i}'))
            for i in range(self.n_posterior_flows)
        ]
        self.posterior_flow_components = iterated(self.n_posterior_flows,
                                                  flow_, self.latent_dim,
                                                  **flow_kwargs)
        self.posterior_flow_transforms = [
            x for c in zip(self.posterior_permutations, self.
                           posterior_affine, self.posterior_flow_components)
            for x in c
        ]
    else:
        self.posterior_affine = []
        self.posterior_permutations = []
        self.posterior_flow_components = flow_(
            self.latent_dim, **
            flow_kwargs) if self.use_posterior_flow else []
        # Same fix as for the prior: no ``[[]]`` when the flow is disabled.
        self.posterior_flow_transforms = [
            self.posterior_flow_components
        ] if self.use_posterior_flow else []